/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreFrequency;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleScoreDistribution;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureImportanceMap;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.HasFeatureImportances;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RFactorVector;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVector;
import org.jpmml.rexp.TreeModelConverter;
import org.jpmml.rexp.XLevelsFormulaContext;

/**
 * Converts an R {@code rpart} object into a PMML {@link TreeModel}.
 *
 * <p>The {@code method} element selects the model type: {@code "anova"} yields a regression tree
 * and {@code "class"} a classification tree. The {@code control$usesurrogate} setting (0, 1 or 2)
 * controls how surrogate splits and missing values are encoded.
 */
public class RPartConverter
extends TreeModelConverter<RGenericVector>
implements HasFeatureImportances {
    /** Value of {@code control$usesurrogate}; validated to be 0, 1 or 2 by the constructor. */
    private int useSurrogate = 0;
    /** Formula captured by {@link #encodeSchema(RExpEncoder)}; used to resolve surrogate split variables. */
    private Formula formula = null;
    /** Split variable index denoting a leaf node (the {@code "<leaf>"} marker in {@code frame$var}). */
    private static final int INDEX_LEAF = 0;

    /**
     * Creates a converter for the given {@code rpart} object.
     *
     * @param rpart the fitted {@code rpart} object
     * @throws IllegalArgumentException if {@code control$usesurrogate} is not 0, 1 or 2
     */
    public RPartConverter(RGenericVector rpart) {
        super(rpart);
        RGenericVector control = rpart.getGenericElement("control");
        RNumberVector<?> useSurrogate = control.getNumericElement("usesurrogate");
        this.useSurrogate = ValueUtil.asInt((Number)useSurrogate.asScalar());
        switch (this.useSurrogate) {
            case 0:
            case 1:
            case 2:
                break;
            default:
                throw new IllegalArgumentException("Expected 0, 1 or 2 as usesurrogate value, got " + this.useSurrogate);
        }
    }

    /**
     * Indicates whether classification nodes carry per-class score distributions
     * (and the classification model a probability output).
     *
     * @return always {@code true} in this implementation
     */
    public boolean hasScoreDistribution() {
        return true;
    }

    /**
     * Builds the formula (label plus split-variable features) for the model.
     *
     * <p>Only variables that occur in {@code frame$var} (i.e. primary split variables) are
     * registered as features; the formula is retained so that surrogate split variables can be
     * resolved later.
     *
     * @param encoder the encoder that receives the label and feature definitions
     * @throws IllegalArgumentException if {@code frame$var} is neither a string nor a factor vector
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector rpart = (RGenericVector)this.getObject();
        RGenericVector frame = rpart.getGenericElement("frame");
        RExp terms = (RExp)rpart.getElement("terms");
        RGenericVector xlevels = rpart.getGenericAttribute("xlevels", false);
        RStringVector ylevels = rpart.getStringAttribute("ylevels", false);
        RVector<?> var = frame.getVectorElement("var");
        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels);
        Formula formula = FormulaUtil.createFormula(terms, context, encoder);
        FormulaUtil.setLabel(formula, terms, ylevels, encoder);
        List<String> names = RPartConverter.getFeatureNames(RPartConverter.getVarNames(var));
        FormulaUtil.addFeatures(formula, names, false, encoder);
        this.formula = formula;
    }

    /**
     * Encodes the rpart tree as a {@link TreeModel}.
     *
     * @param schema the schema whose features were registered by {@link #encodeSchema(RExpEncoder)}
     * @return a regression ({@code "anova"}) or classification ({@code "class"}) tree model
     * @throws IllegalArgumentException if the frame row names contain NA, a split variable cannot
     *         be matched to a feature, or the rpart method is not supported
     */
    public TreeModel encodeModel(Schema schema) {
        RGenericVector rpart = (RGenericVector)this.getObject();
        RGenericVector frame = rpart.getGenericElement("frame");
        RStringVector method = rpart.getStringElement("method");
        RNumberVector<?> splits = rpart.getNumericElement("splits");
        RIntegerVector csplit = rpart.getIntegerElement("csplit", false);
        RVector<?> var = frame.getVectorElement("var");
        RIntegerVector n = frame.getIntegerElement("n");
        RIntegerVector ncompete = frame.getIntegerElement("ncompete");
        RIntegerVector nsurrogate = frame.getIntegerElement("nsurrogate");
        RIntegerVector rowNames = frame.getIntegerAttribute("row.names");
        // Integer.MIN_VALUE is R's NA_integer_; node numbers must be stored explicitly
        if (rowNames.getValues().contains(Integer.MIN_VALUE)) {
            throw new IllegalArgumentException("Expected explicit node numbers as frame row names, got NA");
        }
        List<? extends Feature> features = schema.getFeatures();
        // splitInfo[i] = {first row of node i in `splits`, ncompete, nsurrogate}.
        // A node's rows in `splits` are its primary split (non-leaf nodes only), then its
        // competitor splits, then its surrogate splits; the next node's rows follow directly.
        int[][] splitInfo = new int[1 + rowNames.size()][3];
        for (int offset = 0; offset < rowNames.size(); ++offset) {
            int splitVar = RPartConverter.getFeatureIndex(var, offset, features);
            splitInfo[offset][1] = ncompete.getValue(offset);
            splitInfo[offset][2] = nsurrogate.getValue(offset);
            int primaryRows = (splitVar != INDEX_LEAF) ? 1 : 0;
            splitInfo[offset + 1][0] = splitInfo[offset][0] + splitInfo[offset][1] + splitInfo[offset][2] + primaryRows;
        }
        String methodName = (String)method.asScalar();
        switch (methodName) {
            case "anova":
                return this.encodeRegression(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
            case "class":
                return this.encodeClassification(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
            default:
                throw new IllegalArgumentException("Expected \"anova\" or \"class\" as rpart method, got " + methodName);
        }
    }

    /**
     * Maps each schema feature to its entry in the {@code variable.importance} element.
     *
     * @param schema the schema whose features are looked up
     * @return the importance map, or {@code null} when the rpart object has no
     *         {@code variable.importance} element
     */
    @Override
    public FeatureImportanceMap getFeatureImportances(Schema schema) {
        RGenericVector rpart = (RGenericVector)this.getObject();
        RDoubleVector variableImportance = rpart.getDoubleElement("variable.importance", false);
        if (variableImportance == null) {
            return null;
        }
        List<? extends Feature> features = schema.getFeatures();
        FeatureImportanceMap result = new FeatureImportanceMap(null);
        for (Feature feature : features) {
            Double importance = (Double)variableImportance.getElement(feature.getName());
            result.put(feature, importance);
        }
        return result;
    }

    /**
     * Encodes a regression tree: each node scores {@code frame$yval} and records {@code frame$n}.
     */
    private TreeModel encodeRegression(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, final RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
        final RNumberVector<?> yval = frame.getNumericElement("yval");
        ScoreEncoder scoreEncoder = (node, offset) -> {
            Number score = (Number)yval.getValue(offset);
            Integer recordCount = n.getValue(offset);
            node.setScore(score).setRecordCount(recordCount);
            return node;
        };
        Node root = this.encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)schema.getLabel()), root);
        return this.configureTreeModel(treeModel);
    }

    /**
     * Encodes a classification tree from {@code frame$yval2}.
     *
     * <p>Column 0 of {@code yval2} holds the 1-based index of the predicted category, and columns
     * {@code 1..k} hold per-category record counts (used for score distributions). The remaining
     * {@code k + 1} columns are not read here.
     */
    private TreeModel encodeClassification(RGenericVector frame, final RIntegerVector rowNames, RVector<?> var, final RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
        final RDoubleVector yval2 = frame.getDoubleElement("yval2");
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        final List<?> categories = categoricalLabel.getValues();
        final boolean hasScoreDistribution = this.hasScoreDistribution();
        ScoreEncoder scoreEncoder = new ScoreEncoder() {
            // 1-based category index per node row
            private List<Integer> classes = null;
            // recordCounts.get(i).get(offset) = record count of category i at node row `offset`
            private List<List<? extends Number>> recordCounts = null;

            {
                int rows = rowNames.size();
                int columns = 1 + 2 * categories.size() + 1;
                List classColumn = ValueUtil.asIntegers((List)FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 0));
                this.classes = new ArrayList<Integer>(classColumn);
                if (hasScoreDistribution) {
                    this.recordCounts = new ArrayList<List<? extends Number>>();
                    for (int i = 0; i < categories.size(); ++i) {
                        List countColumn = FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 1 + i);
                        this.recordCounts.add(new ArrayList(countColumn));
                    }
                }
            }

            @Override
            public Node encode(Node node, int offset) {
                Object score = categories.get(this.classes.get(offset) - 1);
                Integer recordCount = n.getValue(offset);
                node.setScore(score).setRecordCount(recordCount);
                if (!hasScoreDistribution) {
                    return node;
                }
                Node classifierNode = new ClassifierNode(node);
                List scoreDistributions = classifierNode.getScoreDistributions();
                for (int i = 0; i < categories.size(); ++i) {
                    List<? extends Number> categoryCounts = this.recordCounts.get(i);
                    SimpleScoreDistribution scoreDistribution = new ScoreFrequency().setValue(categories.get(i)).setRecordCount(categoryCounts.get(offset));
                    scoreDistributions.add(scoreDistribution);
                }
                return classifierNode;
            }
        };
        Node root = this.encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
        TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)schema.getLabel()), root);
        if (hasScoreDistribution) {
            treeModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (DiscreteLabel)categoricalLabel));
        }
        return this.configureTreeModel(treeModel);
    }

    /**
     * Sets the tree-level missing-value and no-true-child strategies according to
     * {@code usesurrogate}: 0 maps to NULL_PREDICTION, 1 to LAST_PREDICTION and 2 leaves the
     * missing-value strategy unset (defaults are encoded per node instead).
     */
    private TreeModel configureTreeModel(TreeModel treeModel) {
        TreeModel.MissingValueStrategy missingValueStrategy;
        TreeModel.NoTrueChildStrategy noTrueChildStrategy = TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION;
        switch (this.useSurrogate) {
            case 0:
                missingValueStrategy = TreeModel.MissingValueStrategy.NULL_PREDICTION;
                break;
            case 1:
                missingValueStrategy = TreeModel.MissingValueStrategy.LAST_PREDICTION;
                break;
            case 2:
                missingValueStrategy = null;
                break;
            default:
                throw new IllegalArgumentException("Expected 0, 1 or 2 as usesurrogate value, got " + this.useSurrogate);
        }
        treeModel.setNoTrueChildStrategy(noTrueChildStrategy).setMissingValueStrategy(missingValueStrategy);
        return treeModel;
    }

    /**
     * Recursively encodes the node with the given rpart node number and its subtree.
     *
     * <p>The children of node {@code rowName} are nodes {@code 2 * rowName} (left) and
     * {@code 2 * rowName + 1} (right). When {@code usesurrogate > 0}, child predicates are
     * SURROGATE compound predicates that list the primary split followed by its surrogate splits.
     * When {@code usesurrogate == 2}, the child with more records gets a trailing {@code True}
     * and is placed last; on a tie neither child gets one.
     *
     * @param predicate the predicate guarding this node
     * @param rowName the rpart node number (a value of the frame row names)
     * @return the encoded node, scored by {@code scoreEncoder}
     * @throws IllegalArgumentException if a node number or split variable cannot be resolved
     */
    private Node encodeNode(Predicate predicate, int rowName, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, ScoreEncoder scoreEncoder, Schema schema) {
        int offset = RPartConverter.getIndex(rowNames, rowName);
        Integer id = rowName;
        List<? extends Feature> features = schema.getFeatures();
        int splitVar = RPartConverter.getFeatureIndex(var, offset, features);
        if (splitVar == INDEX_LEAF) {
            SimpleNode leaf = new CountingLeafNode(null, predicate).setId(id);
            return scoreEncoder.encode(leaf, offset);
        }
        int leftRowName = rowName * 2;
        int rightRowName = rowName * 2 + 1;
        // Sign of (left count - right count); only computed and used when useSurrogate == 2
        int majorityDir = 0;
        if (this.useSurrogate == 2) {
            int leftOffset = RPartConverter.getIndex(rowNames, leftRowName);
            int rightOffset = RPartConverter.getIndex(rowNames, rightRowName);
            majorityDir = Integer.compare(n.getValue(leftOffset).intValue(), n.getValue(rightOffset).intValue());
        }
        Feature feature = features.get(splitVar - 1);
        int splitOffset = splitInfo[offset][0];
        int splitNumCompete = splitInfo[offset][1];
        int splitNumSurrogate = splitInfo[offset][2];
        List<Predicate> predicates = this.encodePredicates(feature, splitOffset, splits, csplit);
        Predicate leftPredicate = predicates.get(0);
        Predicate rightPredicate = predicates.get(1);
        if (this.useSurrogate > 0 && splitNumSurrogate > 0) {
            CompoundPredicate leftCompoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(leftPredicate);
            CompoundPredicate rightCompoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(rightPredicate);
            RStringVector splitRowNames = splits.dimnames(0);
            for (int i = 0; i < splitNumSurrogate; ++i) {
                // Surrogate rows follow the primary split row and the competitor rows
                int surrogateSplitOffset = splitOffset + 1 + splitNumCompete + i;
                Feature surrogateFeature = this.getFeature(splitRowNames.getValue(surrogateSplitOffset));
                List<Predicate> surrogatePredicates = this.encodePredicates(surrogateFeature, surrogateSplitOffset, splits, csplit);
                leftCompoundPredicate.addPredicates(surrogatePredicates.get(0));
                rightCompoundPredicate.addPredicates(surrogatePredicates.get(1));
            }
            leftPredicate = leftCompoundPredicate;
            rightPredicate = rightCompoundPredicate;
        }
        Node leftChild = this.encodeNode(leftPredicate, leftRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
        Node rightChild = this.encodeNode(rightPredicate, rightRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
        if (this.useSurrogate == 2) {
            if (majorityDir < 0) {
                this.makeDefault(rightChild);
            } else if (majorityDir > 0) {
                this.makeDefault(leftChild);
                // Place the default (majority) child last
                Node tmp = leftChild;
                leftChild = rightChild;
                rightChild = tmp;
            }
        }
        Node result = new CountingBranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
        return scoreEncoder.encode(result, offset);
    }

    /**
     * Creates the left and right child predicates for the split stored at row {@code splitOffset}
     * of the {@code splits} matrix.
     *
     * <p>Column 1 of {@code splits} ("ncat") gives the split type. A value of {@code -1} or
     * {@code 1} means a continuous split against the threshold in column 3 ("index"): {@code -1}
     * sends values below the threshold left, {@code 1} sends them right. Any other value means a
     * categorical split whose "index" is a 1-based row of {@code csplit}. In that row, level flags
     * 1 and 3 select the left and right levels respectively.
     *
     * @return a two-element list: left predicate, then right predicate
     */
    private List<Predicate> encodePredicates(Feature feature, int splitOffset, RNumberVector<?> splits, RIntegerVector csplit) {
        RIntegerVector splitsDim = splits.dim();
        int splitRows = splitsDim.getValue(0);
        int splitColumns = splitsDim.getValue(1);
        List ncat = FortranMatrixUtil.getColumn(splits.getValues(), splitRows, splitColumns, 1);
        List index = FortranMatrixUtil.getColumn(splits.getValues(), splitRows, splitColumns, 3);
        int splitType = ValueUtil.asInt((Number)ncat.get(splitOffset));
        Number splitValue = (Number)index.get(splitOffset);
        Predicate leftPredicate;
        Predicate rightPredicate;
        if (Math.abs(splitType) == 1) {
            SimplePredicate.Operator leftOperator;
            SimplePredicate.Operator rightOperator;
            if (splitType == -1) {
                leftOperator = SimplePredicate.Operator.LESS_THAN;
                rightOperator = SimplePredicate.Operator.GREATER_OR_EQUAL;
            } else {
                leftOperator = SimplePredicate.Operator.GREATER_OR_EQUAL;
                rightOperator = SimplePredicate.Operator.LESS_THAN;
            }
            leftPredicate = this.createSimplePredicate(feature, leftOperator, splitValue);
            rightPredicate = this.createSimplePredicate(feature, rightOperator, splitValue);
        } else {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            RIntegerVector csplitDim = csplit.dim();
            int csplitRows = csplitDim.getValue(0);
            int csplitColumns = csplitDim.getValue(1);
            List csplitRow = FortranMatrixUtil.getRow(csplit.getValues(), csplitRows, csplitColumns, ValueUtil.asInt(splitValue) - 1);
            List values = categoricalFeature.getValues();
            leftPredicate = this.createPredicate(categoricalFeature, RPartConverter.selectValues(values, csplitRow, 1));
            rightPredicate = this.createPredicate(categoricalFeature, RPartConverter.selectValues(values, csplitRow, 3));
        }
        return Arrays.asList(leftPredicate, rightPredicate);
    }

    /**
     * Turns {@code node} into the default child by appending {@code True} as a final surrogate.
     *
     * <p>Only an existing SURROGATE compound predicate is extended in place. Any other predicate,
     * including a compound predicate with a different boolean operator, is first wrapped in a new
     * SURROGATE predicate. Appending {@code True} to, e.g., an OR predicate would make it match
     * every record.
     */
    private void makeDefault(Node node) {
        Predicate predicate = node.requirePredicate();
        CompoundPredicate compoundPredicate;
        if (predicate instanceof CompoundPredicate && ((CompoundPredicate)predicate).getBooleanOperator() == CompoundPredicate.BooleanOperator.SURROGATE) {
            compoundPredicate = (CompoundPredicate)predicate;
        } else {
            compoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(predicate);
            node.setPredicate(compoundPredicate);
        }
        compoundPredicate.addPredicates(True.INSTANCE);
    }

    /**
     * Resolves a (surrogate) split variable name against the formula built by {@code encodeSchema}.
     *
     * @throws IllegalStateException if {@code encodeSchema} has not been called yet
     */
    private Feature getFeature(String name) {
        if (this.formula == null) {
            throw new IllegalStateException("Formula is not available; encodeSchema must be called first");
        }
        return this.formula.resolveComplexFeature(name);
    }

    /** Returns the distinct non-leaf variable names, in order of first appearance. */
    private static List<String> getFeatureNames(List<String> names) {
        return names.stream().filter(name -> !"<leaf>".equals(name)).distinct().collect(Collectors.toList());
    }

    /**
     * Returns the per-node split variable names from {@code frame$var}.
     *
     * @throws IllegalArgumentException if {@code var} is neither a string nor a factor vector
     */
    private static List<String> getVarNames(RVector<?> var) {
        if (var instanceof RStringVector) {
            List<String> values = ((RStringVector)var).getValues();
            return values;
        }
        if (var instanceof RFactorVector) {
            List<String> values = ((RFactorVector)var).getFactorValues();
            return values;
        }
        throw new IllegalArgumentException("Expected a string or factor vector as frame$var, got " + (var != null ? var.getClass().getName() : null));
    }

    /**
     * Returns the split variable name of the node at {@code offset}.
     *
     * @throws IllegalArgumentException if {@code var} is neither a string nor a factor vector
     */
    private static String getVarName(RVector<?> var, int offset) {
        if (var instanceof RStringVector) {
            return ((RStringVector)var).getValue(offset);
        }
        return RPartConverter.getVarNames(var).get(offset);
    }

    /**
     * Returns the 1-based index into {@code features} of the split variable of the node at
     * {@code offset}, or {@link #INDEX_LEAF} when the node is a leaf.
     *
     * <p>Both string and factor {@code var} columns are resolved by name. Factor level codes are
     * not used, because their order need not match the order in which features were registered.
     *
     * @throws IllegalArgumentException if the variable name does not match any feature
     */
    private static int getFeatureIndex(RVector<?> var, int offset, List<? extends Feature> features) {
        String varName = RPartConverter.getVarName(var, offset);
        if ("<leaf>".equals(varName)) {
            return INDEX_LEAF;
        }
        for (int i = 0; i < features.size(); ++i) {
            String name = features.get(i).getName();
            if (name.equals(varName)) {
                return i + 1;
            }
        }
        throw new IllegalArgumentException("Split variable \"" + varName + "\" does not match any feature");
    }

    /**
     * Returns the row offset of node number {@code rowName} within {@code rowNames}.
     *
     * @throws IllegalArgumentException if the node number is not present
     */
    private static int getIndex(RIntegerVector rowNames, int rowName) {
        int index = rowNames.indexOf(rowName);
        if (index < 0) {
            throw new IllegalArgumentException("Node " + rowName + " is not present in frame row names");
        }
        return index;
    }

    /** Returns the elements of {@code values} whose positionally matching flag equals {@code flag}. */
    private static <E> List<E> selectValues(List<E> values, List<Integer> valueFlags, int flag) {
        List<E> result = new ArrayList<>(values.size());
        for (int i = 0; i < values.size(); ++i) {
            Integer valueFlag = valueFlags.get(i);
            if (valueFlag == flag) {
                result.add(values.get(i));
            }
        }
        return result;
    }

    /** Attaches score (and record count) information to a node, identified by its frame row offset. */
    @FunctionalInterface
    private static interface ScoreEncoder {
        public Node encode(Node node, int offset);
    }
}

