package weka.classifiers.trees;

import org.xmlpull.v1.XmlPullParser;
import weka.classifiers.Classifier;
import weka.classifiers.Sourcable;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:pmmlDevelopment/lib/weka.jar:weka/classifiers/trees/DecisionStump.class */
/**
 * Decision stump: a one-level decision tree that splits on a single attribute.
 *
 * <p>A built stump has three branches, indexed 0..2 in {@link #m_Distribution}:
 * <ul>
 *   <li>0 - split attribute equals the chosen value (nominal attribute) or is
 *       {@code <=} the threshold (numeric attribute);</li>
 *   <li>1 - the complement of branch 0;</li>
 *   <li>2 - split attribute is missing.</li>
 * </ul>
 * For a nominal class the attribute minimising
 * {@link ContingencyTables#entropyConditionedOnRows} is chosen; for a numeric class the attribute
 * minimising the summed within-branch squared error (see {@link #variance}) is chosen. Instance
 * weights are used throughout. If the training data contains only the class attribute, a
 * {@link ZeroR} model is built and used instead.
 */
public class DecisionStump extends Classifier implements WeightedInstancesHandler, Sourcable {

    /** Serialization id; must stay stable so previously saved models still deserialize. */
    static final long serialVersionUID = 1618384535950391L;

    /** Index of the attribute the stump splits on. */
    private int m_AttIndex;

    /**
     * Split point. For a nominal split attribute this is the index of the value routed to
     * branch 0; for a numeric split attribute it is the threshold (values {@code <=} it go to
     * branch 0).
     */
    private double m_SplitPoint;

    /**
     * Per-branch predictions indexed [branch][column]. After building: for a nominal class each
     * row is a normalised class distribution; for a numeric class each row has a single entry,
     * the predicted (weighted mean) class value. During the split search it is used as a scratch
     * accumulator for raw counts/sums.
     */
    private double[][] m_Distribution;

    /** Training data while building; afterwards only the header (no instances) for printing. */
    private Instances m_Instances;

    /** Fallback model used when the data contains only the class attribute; null otherwise. */
    private Classifier m_ZeroR;

    /**
     * Returns a description of this classifier suitable for display in the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Class for building and using a decision stump. Usually used in conjunction with a boosting algorithm. Does regression (based on mean-squared error) or classification (based on entropy). Missing is treated as a separate value.";
    }

    /**
     * Returns the capabilities of this classifier: nominal, numeric and date attributes and
     * classes, plus missing attribute and class values.
     *
     * @return the capabilities
     */
    @Override // weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);

        // class
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        return result;
    }

    /**
     * Builds the stump by evaluating the best split on every non-class attribute and keeping the
     * best-scoring one. Instances with a missing class value are ignored.
     *
     * @param instances the training data
     * @throws Exception if the data cannot be handled or a split search fails
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        // Can the classifier handle this data?
        getCapabilities().testWithFail(instances);

        // Work on a copy without instances whose class value is missing.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        // Only the class attribute is present: nothing to split on, fall back to ZeroR.
        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            m_ZeroR = new ZeroR();
            m_ZeroR.buildClassifier(data);
            return;
        }
        m_ZeroR = null;

        m_Instances = new Instances(data);
        boolean nominalClass = m_Instances.classAttribute().isNominal();
        // Width of each branch row: one column per class, or a single column for a numeric class.
        int numCols = nominalClass ? m_Instances.numClasses() : 1;

        double bestScore = Double.MAX_VALUE;
        double bestSplitPoint = -Double.MAX_VALUE;
        int bestAttIndex = -1;
        double[][] bestDist = new double[3][data.numClasses()];
        boolean first = true;

        for (int att = 0; att < m_Instances.numAttributes(); att++) {
            if (att == m_Instances.classIndex()) {
                continue;
            }
            // Fresh scratch accumulator; the findSplit* methods fill it and replace it with the
            // best per-branch result for this attribute.
            m_Distribution = new double[3][numCols];
            double score = m_Instances.attribute(att).isNominal()
                    ? findSplitNominal(att)
                    : findSplitNumeric(att);

            // The first candidate is always accepted so that a model exists even when every
            // attribute scores Double.MAX_VALUE.
            if (first || score < bestScore) {
                bestScore = score;
                bestAttIndex = att;
                bestSplitPoint = m_SplitPoint;
                for (int branch = 0; branch < 3; branch++) {
                    System.arraycopy(m_Distribution[branch], 0, bestDist[branch], 0, numCols);
                }
            }
            first = false;
        }

        m_AttIndex = bestAttIndex;
        m_SplitPoint = bestSplitPoint;
        m_Distribution = bestDist;

        if (nominalClass) {
            for (int branch = 0; branch < m_Distribution.length; branch++) {
                double sum = Utils.sum(m_Distribution[branch]);
                if (sum == 0.0) {
                    // Branch received no training weight: reuse the (not yet normalised)
                    // missing-value branch counts, which hold the overall counts when no
                    // values were missing.
                    System.arraycopy(m_Distribution[2], 0, m_Distribution[branch], 0,
                            m_Distribution[2].length);
                    Utils.normalize(m_Distribution[branch]);
                } else {
                    Utils.normalize(m_Distribution[branch], sum);
                }
            }
        }

        // Keep only the header: enough for toString()/toSource() without retaining the data.
        m_Instances = new Instances(m_Instances, 0);
    }

    /**
     * Returns the prediction of the branch the instance falls into (or of the ZeroR fallback).
     *
     * @param instance the instance to classify
     * @return the class distribution (nominal class) or a one-element array holding the
     *     predicted value (numeric class)
     * @throws Exception if classification fails
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_ZeroR != null) {
            return m_ZeroR.distributionForInstance(instance);
        }
        return m_Distribution[whichSubset(instance)];
    }

    /**
     * Generates Java source for a class with a static {@code classify(Object[] i)} method that
     * reproduces this stump's predictions: it returns the class index (nominal class) or the
     * predicted value (numeric class). If the ZeroR fallback is in use, source generation is
     * delegated to it.
     *
     * @param className the name of the generated class
     * @return the generated source code
     * @throws IllegalStateException if no model has been built yet, or the ZeroR fallback is in
     *     use but cannot generate source
     * @throws Exception if the fallback model's source generation fails
     */
    @Override // weka.classifiers.Sourcable
    public String toSource(String className) throws Exception {
        if (m_ZeroR != null) {
            // The stump has no split of its own; m_Instances may be null or stale here.
            if (m_ZeroR instanceof Sourcable) {
                return ((Sourcable) m_ZeroR).toSource(className);
            }
            throw new IllegalStateException("ZeroR fallback model cannot generate source code.");
        }
        if (m_Instances == null) {
            throw new IllegalStateException("Decision Stump: No model built yet.");
        }

        Attribute classAtt = m_Instances.classAttribute();
        Attribute splitAtt = m_Instances.attribute(m_AttIndex);

        StringBuilder src = new StringBuilder("class ");
        src.append(className).append(" {\n  public static double classify(Object[] i) {\n");
        // The name goes into a block comment; neutralise a "*/" that would end it early.
        src.append("    /* ").append(splitAtt.name().replace("*/", "* /")).append(" */\n");
        src.append("    if (i[").append(m_AttIndex).append("] == null) { return ")
                .append(sourceClass(classAtt, m_Distribution[2])).append(";");
        if (splitAtt.isNominal()) {
            // The value becomes a string literal and must be escaped to compile correctly.
            src.append(" } else if (((String)i[").append(m_AttIndex).append("]).equals(\"")
                    .append(escapeJavaString(splitAtt.value((int) m_SplitPoint)))
                    .append("\")");
        } else {
            src.append(" } else if (((Double)i[").append(m_AttIndex)
                    .append("]).doubleValue() <= ").append(m_SplitPoint);
        }
        src.append(") { return ").append(sourceClass(classAtt, m_Distribution[0])).append(";");
        src.append(" } else { return ").append(sourceClass(classAtt, m_Distribution[1])).append(";");
        src.append(" }\n  }\n}\n");
        return src.toString();
    }

    /**
     * Escapes a string so it can be embedded between double quotes in Java source.
     *
     * @param s the raw string
     * @return the escaped string (without surrounding quotes)
     */
    private static String escapeJavaString(String s) {
        StringBuilder out = new StringBuilder(s.length() + 8);
        for (int k = 0; k < s.length(); k++) {
            char ch = s.charAt(k);
            switch (ch) {
                case '\\':
                    out.append("\\\\");
                    break;
                case '"':
                    out.append("\\\"");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    if (ch < 0x20) {
                        // Octal escape: unlike a unicode escape it is not translated before
                        // lexing, so it cannot inject a raw line break into the literal.
                        out.append(String.format("\\%03o", (int) ch));
                    } else {
                        out.append(ch);
                    }
            }
        }
        return out.toString();
    }

    /**
     * Renders a branch prediction as a Java expression.
     *
     * @param attribute the class attribute
     * @param dist the branch row of {@link #m_Distribution}
     * @return the index of the most likely class (nominal class) or the predicted value
     */
    private String sourceClass(Attribute attribute, double[] dist) {
        return attribute.isNominal() ? Integer.toString(Utils.maxIndex(dist)) : Double.toString(dist[0]);
    }

    /**
     * Returns a human-readable description of the stump.
     *
     * @return the description
     */
    @Override
    public String toString() {
        if (m_ZeroR != null) {
            String simpleName = getClass().getName().replaceAll(".*\\.", "");
            StringBuilder text = new StringBuilder();
            text.append(simpleName).append("\n");
            text.append(simpleName.replaceAll(".", "=")).append("\n\n");
            text.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            text.append(m_ZeroR.toString());
            return text.toString();
        }
        if (m_Instances == null) {
            return "Decision Stump: No model built yet.";
        }
        try {
            Attribute att = m_Instances.attribute(m_AttIndex);

            // Labels for branch 0, branch 1 and the missing-value branch.
            String leftLabel;
            String rightLabel;
            if (att.isNominal()) {
                String value = att.value((int) m_SplitPoint);
                leftLabel = att.name() + " = " + value;
                rightLabel = att.name() + " != " + value;
            } else {
                leftLabel = att.name() + " <= " + m_SplitPoint;
                rightLabel = att.name() + " > " + m_SplitPoint;
            }
            String missingLabel = att.name() + " is missing";

            StringBuilder text = new StringBuilder();
            text.append("Decision Stump\n\n");
            text.append("Classifications\n\n");
            text.append(leftLabel).append(" : ").append(printClass(m_Distribution[0]));
            text.append(rightLabel).append(" : ").append(printClass(m_Distribution[1]));
            text.append(missingLabel).append(" : ").append(printClass(m_Distribution[2]));

            if (m_Instances.classAttribute().isNominal()) {
                text.append("\nClass distributions\n\n");
                text.append(leftLabel).append("\n").append(printDist(m_Distribution[0]));
                text.append(rightLabel).append("\n").append(printDist(m_Distribution[1]));
                text.append(missingLabel).append("\n").append(printDist(m_Distribution[2]));
            }
            return text.toString();
        } catch (Exception e) {
            return "Can't print decision stump classifier!";
        }
    }

    /**
     * Formats a class distribution as a header row of class labels and a row of probabilities
     * (tab-separated). Returns an empty string for a numeric class.
     *
     * @param dist the distribution to print
     * @return the formatted distribution
     * @throws Exception if printing fails
     */
    private String printDist(double[] dist) throws Exception {
        StringBuilder text = new StringBuilder();
        Attribute classAtt = m_Instances.classAttribute();
        if (classAtt.isNominal()) {
            int numClasses = m_Instances.numClasses();
            for (int c = 0; c < numClasses; c++) {
                text.append(classAtt.value(c)).append("\t");
            }
            text.append("\n");
            for (int c = 0; c < numClasses; c++) {
                text.append(dist[c]).append("\t");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Formats the prediction of a branch followed by a newline.
     *
     * @param dist the branch row of {@link #m_Distribution}
     * @return the most likely class label (nominal class) or the predicted value
     * @throws Exception if printing fails
     */
    private String printClass(double[] dist) throws Exception {
        Attribute classAtt = m_Instances.classAttribute();
        String prediction = classAtt.isNominal()
                ? classAtt.value(Utils.maxIndex(dist))
                : Double.toString(dist[0]);
        return prediction + "\n";
    }

    /**
     * Finds the best split on a nominal attribute.
     *
     * @param index the attribute index
     * @return the split score (lower is better)
     * @throws Exception if the search fails
     */
    private double findSplitNominal(int index) throws Exception {
        return m_Instances.classAttribute().isNominal()
                ? findSplitNominalNominal(index)
                : findSplitNominalNumeric(index);
    }

    /**
     * Finds the best "value == v" versus "value != v" split of a nominal attribute for a nominal
     * class, scored by conditional entropy. On return {@link #m_Distribution} holds the
     * per-branch class counts of the best split and {@link #m_SplitPoint} its value index.
     *
     * @param index the attribute index
     * @return the conditional entropy of the best split, or Double.MAX_VALUE if none was found
     * @throws Exception if the search fails
     */
    private double findSplitNominalNominal(int index) throws Exception {
        int numValues = m_Instances.attribute(index).numValues();
        int numClasses = m_Instances.numClasses();
        double bestVal = Double.MAX_VALUE;
        // counts[v][c]: weight of class c with attribute value v; row numValues = missing.
        double[][] counts = new double[numValues + 1][numClasses];
        // Per-class weight over all non-missing values.
        double[] sumCounts = new double[numClasses];
        double[][] bestDist = new double[3][numClasses];
        int numMissing = 0;

        for (int i = 0; i < m_Instances.numInstances(); i++) {
            Instance inst = m_Instances.instance(i);
            int row;
            if (inst.isMissing(index)) {
                numMissing++;
                row = numValues;
            } else {
                row = (int) inst.value(index);
            }
            counts[row][(int) inst.classValue()] += inst.weight();
        }

        for (int v = 0; v < numValues; v++) {
            for (int c = 0; c < numClasses; c++) {
                sumCounts[c] += counts[v][c];
            }
        }

        // The missing-value row is the same for every candidate split.
        System.arraycopy(counts[numValues], 0, m_Distribution[2], 0, numClasses);

        for (int v = 0; v < numValues; v++) {
            for (int c = 0; c < numClasses; c++) {
                m_Distribution[0][c] = counts[v][c];
                m_Distribution[1][c] = sumCounts[c] - counts[v][c];
            }
            double currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
            if (currVal < bestVal) {
                bestVal = currVal;
                m_SplitPoint = v;
                for (int branch = 0; branch < 3; branch++) {
                    System.arraycopy(m_Distribution[branch], 0, bestDist[branch], 0, numClasses);
                }
            }
        }

        // Without missing training values, predict the overall distribution for a missing value.
        if (numMissing == 0) {
            System.arraycopy(sumCounts, 0, bestDist[2], 0, numClasses);
        }

        m_Distribution = bestDist;
        return bestVal;
    }

    /**
     * Finds the best "value == v" versus "value != v" split of a nominal attribute for a numeric
     * class, scored by summed within-branch squared error. On return {@link #m_Distribution}
     * holds the per-branch weighted mean class values of the best split.
     *
     * @param index the attribute index
     * @return the score of the best split, or Double.MAX_VALUE if the total weight is not
     *     positive or no split was found
     * @throws Exception if the search fails
     */
    private double findSplitNominalNumeric(int index) throws Exception {
        int numValues = m_Instances.attribute(index).numValues();
        double bestVal = Double.MAX_VALUE;

        // Per attribute value, over non-missing instances.
        double[] sumsSquaresPerValue = new double[numValues];
        double[] sumsPerValue = new double[numValues];
        double[] weightsPerValue = new double[numValues];

        // Totals over non-missing instances.
        double nonMissingSumSquares = 0.0;
        double nonMissingSum = 0.0;
        double nonMissingWeight = 0.0;

        // Totals over all instances (used as the fallback prediction for empty branches).
        double allWeight = 0.0;
        double allSum = 0.0;

        // Per-branch sums of squares and weights; per-branch sums live in m_Distribution[b][0].
        double[] sumsSquares = new double[3];
        double[] sumOfWeights = new double[3];
        double[][] bestDist = new double[3][1];

        for (int i = 0; i < m_Instances.numInstances(); i++) {
            Instance inst = m_Instances.instance(i);
            double classValue = inst.classValue();
            double weight = inst.weight();
            if (inst.isMissing(index)) {
                m_Distribution[2][0] += classValue * weight;
                sumsSquares[2] += classValue * classValue * weight;
                sumOfWeights[2] += weight;
            } else {
                int v = (int) inst.value(index);
                weightsPerValue[v] += weight;
                sumsPerValue[v] += classValue * weight;
                sumsSquaresPerValue[v] += classValue * classValue * weight;
            }
            allWeight += weight;
            allSum += classValue * weight;
        }

        if (allWeight <= 0.0) {
            return Double.MAX_VALUE;
        }

        for (int v = 0; v < numValues; v++) {
            nonMissingWeight += weightsPerValue[v];
            nonMissingSumSquares += sumsSquaresPerValue[v];
            nonMissingSum += sumsPerValue[v];
        }

        for (int v = 0; v < numValues; v++) {
            m_Distribution[0][0] = sumsPerValue[v];
            sumsSquares[0] = sumsSquaresPerValue[v];
            sumOfWeights[0] = weightsPerValue[v];
            m_Distribution[1][0] = nonMissingSum - sumsPerValue[v];
            sumsSquares[1] = nonMissingSumSquares - sumsSquaresPerValue[v];
            sumOfWeights[1] = nonMissingWeight - weightsPerValue[v];

            double currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
            if (currVal < bestVal) {
                bestVal = currVal;
                m_SplitPoint = v;
                copyBranchMeans(m_Distribution, sumOfWeights, allSum / allWeight, bestDist);
            }
        }

        m_Distribution = bestDist;
        return bestVal;
    }

    /**
     * Finds the best split on a numeric attribute.
     *
     * @param index the attribute index
     * @return the split score (lower is better)
     * @throws Exception if the search fails
     */
    private double findSplitNumeric(int index) throws Exception {
        return m_Instances.classAttribute().isNominal()
                ? findSplitNumericNominal(index)
                : findSplitNumericNumeric(index);
    }

    /**
     * Finds the best threshold on a numeric attribute for a nominal class, scored by
     * conditional entropy. Candidate thresholds are midpoints between consecutive distinct
     * values. Side effect: sorts {@link #m_Instances} on the attribute.
     *
     * @param index the attribute index
     * @return the conditional entropy of the best split, or Double.MAX_VALUE if none was found
     * @throws Exception if the search fails
     */
    private double findSplitNumericNominal(int index) throws Exception {
        int numClasses = m_Instances.numClasses();
        double bestVal = Double.MAX_VALUE;
        int numMissing = 0;
        double[] sumCounts = new double[numClasses];
        double[][] bestDist = new double[3][numClasses];

        // Start with every non-missing instance in branch 1 and every missing one in branch 2.
        for (int i = 0; i < m_Instances.numInstances(); i++) {
            Instance inst = m_Instances.instance(i);
            int cls = (int) inst.classValue();
            if (inst.isMissing(index)) {
                m_Distribution[2][cls] += inst.weight();
                numMissing++;
            } else {
                m_Distribution[1][cls] += inst.weight();
            }
        }

        System.arraycopy(m_Distribution[1], 0, sumCounts, 0, numClasses);
        for (int branch = 0; branch < 3; branch++) {
            System.arraycopy(m_Distribution[branch], 0, bestDist[branch], 0, numClasses);
        }

        // Sweep the threshold upwards, moving one instance at a time from branch 1 to branch 0.
        // NOTE: the loop bound assumes Instances.sort places missing values last.
        m_Instances.sort(index);
        for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
            Instance inst = m_Instances.instance(i);
            Instance next = m_Instances.instance(i + 1);
            int cls = (int) inst.classValue();
            m_Distribution[0][cls] += inst.weight();
            m_Distribution[1][cls] -= inst.weight();

            // A threshold is only possible between two distinct attribute values.
            if (inst.value(index) < next.value(index)) {
                double splitPoint = (inst.value(index) + next.value(index)) / 2.0;
                double currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
                if (currVal < bestVal) {
                    m_SplitPoint = splitPoint;
                    bestVal = currVal;
                    for (int branch = 0; branch < 3; branch++) {
                        System.arraycopy(m_Distribution[branch], 0, bestDist[branch], 0, numClasses);
                    }
                }
            }
        }

        // Without missing training values, predict the overall distribution for a missing value.
        if (numMissing == 0) {
            System.arraycopy(sumCounts, 0, bestDist[2], 0, numClasses);
        }

        m_Distribution = bestDist;
        return bestVal;
    }

    /**
     * Finds the best threshold on a numeric attribute for a numeric class, scored by summed
     * within-branch squared error. Candidate thresholds are midpoints between consecutive
     * distinct values. Side effect: sorts {@link #m_Instances} on the attribute.
     *
     * @param index the attribute index
     * @return the score of the best split, or Double.MAX_VALUE if the total weight is not
     *     positive or no split was found
     * @throws Exception if the search fails
     */
    private double findSplitNumericNumeric(int index) throws Exception {
        double bestVal = Double.MAX_VALUE;
        int numMissing = 0;
        // Per-branch sums of squares and weights; per-branch sums live in m_Distribution[b][0].
        double[] sumsSquares = new double[3];
        double[] sumOfWeights = new double[3];
        double[][] bestDist = new double[3][1];
        double totalSum = 0.0;
        double totalSumOfWeights = 0.0;

        // Start with every non-missing instance in branch 1 and every missing one in branch 2.
        for (int i = 0; i < m_Instances.numInstances(); i++) {
            Instance inst = m_Instances.instance(i);
            double classValue = inst.classValue();
            double weight = inst.weight();
            if (inst.isMissing(index)) {
                m_Distribution[2][0] += classValue * weight;
                sumsSquares[2] += classValue * classValue * weight;
                sumOfWeights[2] += weight;
                numMissing++;
            } else {
                m_Distribution[1][0] += classValue * weight;
                sumsSquares[1] += classValue * classValue * weight;
                sumOfWeights[1] += weight;
            }
            totalSumOfWeights += weight;
            totalSum += classValue * weight;
        }

        if (totalSumOfWeights <= 0.0) {
            return Double.MAX_VALUE;
        }

        // Sweep the threshold upwards, moving one instance at a time from branch 1 to branch 0.
        // NOTE: the loop bound assumes Instances.sort places missing values last.
        m_Instances.sort(index);
        for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
            Instance inst = m_Instances.instance(i);
            Instance next = m_Instances.instance(i + 1);
            double classValue = inst.classValue();
            double weight = inst.weight();

            m_Distribution[0][0] += classValue * weight;
            sumsSquares[0] += classValue * classValue * weight;
            sumOfWeights[0] += weight;
            m_Distribution[1][0] -= classValue * weight;
            sumsSquares[1] -= classValue * classValue * weight;
            sumOfWeights[1] -= weight;

            // A threshold is only possible between two distinct attribute values.
            if (inst.value(index) < next.value(index)) {
                double splitPoint = (inst.value(index) + next.value(index)) / 2.0;
                double currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
                if (currVal < bestVal) {
                    m_SplitPoint = splitPoint;
                    bestVal = currVal;
                    copyBranchMeans(m_Distribution, sumOfWeights, totalSum / totalSumOfWeights, bestDist);
                }
            }
        }

        m_Distribution = bestDist;
        return bestVal;
    }

    /**
     * Stores each branch's weighted mean class value into {@code target}; branches with no
     * positive weight get {@code fallbackMean} instead.
     *
     * @param sums per-branch weighted sums of the class value, in column 0
     * @param sumOfWeights per-branch total weights
     * @param fallbackMean prediction for branches without weight
     * @param target destination, indexed [branch][0]
     */
    private static void copyBranchMeans(double[][] sums, double[] sumOfWeights,
                                        double fallbackMean, double[][] target) {
        for (int branch = 0; branch < 3; branch++) {
            if (sumOfWeights[branch] > 0.0) {
                target[branch][0] = sums[branch][0] / sumOfWeights[branch];
            } else {
                target[branch][0] = fallbackMean;
            }
        }
    }

    /**
     * Computes the summed within-branch squared error {@code sum_b (SS_b - S_b^2 / W_b)} over
     * the branches with positive weight.
     *
     * @param sums per-branch weighted sums of the class value, in column 0
     * @param sumOfSquares per-branch weighted sums of squared class values
     * @param sumOfWeights per-branch total weights
     * @return the summed squared error
     */
    private double variance(double[][] sums, double[] sumOfSquares, double[] sumOfWeights) {
        double total = 0.0;
        for (int branch = 0; branch < sums.length; branch++) {
            if (sumOfWeights[branch] > 0.0) {
                total += sumOfSquares[branch] - (sums[branch][0] * sums[branch][0]) / sumOfWeights[branch];
            }
        }
        return total;
    }

    /**
     * Returns the branch an instance falls into.
     *
     * @param instance the instance
     * @return 2 if the split attribute is missing, 0 if it matches the split value (nominal) or
     *     is {@code <=} the threshold (numeric), otherwise 1
     * @throws Exception if the attribute cannot be read
     */
    private int whichSubset(Instance instance) throws Exception {
        if (instance.isMissing(m_AttIndex)) {
            return 2;
        }
        double value = instance.value(m_AttIndex);
        if (instance.attribute(m_AttIndex).isNominal()) {
            return ((int) value) == m_SplitPoint ? 0 : 1;
        }
        return value <= m_SplitPoint ? 0 : 1;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override // weka.classifiers.Classifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5485 $");
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the options
     */
    public static void main(String[] strArr) {
        runClassifier(new DecisionStump(), strArr);
    }
}
