/*
 * Decompiled with CFR 0.152.
 */
package ace;

import ace.InstanceClassifier;
import ace.Trainer;
import ace.datatypes.CrossValidationResults;
import ace.datatypes.TrainedModel;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.text.DecimalFormat;
import java.util.LinkedList;
import mckay.utilities.staticlibraries.MathAndStatsMethods;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Splits a set of Weka instances into folds and runs cross-validation over them.
 *
 * <p>The partitioning is computed once, in the constructor, and stored in {@link #folds}
 * and {@link #names}. {@link #crossValidate} then trains and tests once per fold and
 * produces a textual report.
 */
public class CrossValidator {
    /**
     * Per-fold instance slots. {@code folds[fold][0]} holds the training instances and
     * {@code folds[fold][1]} the testing instances for that fold; slots that do not
     * belong to the respective set are {@code null}.
     */
    Instance[][][] folds;

    /** Instance identifiers, laid out in parallel with {@link #folds}. */
    String[][][] names;

    /**
     * Creates a cross-validator with a randomly generated partition.
     *
     * @param instances   the instances to partition
     * @param num_folds   the number of folds
     * @param identifiers identifier of each instance, indexed like {@code instances}
     * @param num_overall the number of entries to distribute over the folds
     * @param hierarchy   hierarchy label of each instance; entries without an underscore
     *                    are treated as top-level ("overall") instances
     */
    public CrossValidator(Instances instances, int num_folds, String[] identifiers, int num_overall, String[] hierarchy) {
        int[] partition = CrossValidator.generatePartitionArray(num_folds, num_overall);
        this.folds = this.partitionInstances(instances, partition, num_folds, identifiers, hierarchy);
    }

    /**
     * Creates a cross-validator from a caller-supplied partition.
     *
     * <p>Note that {@code hierarchy} and {@code identifiers} appear in the opposite order
     * compared to the other constructor.
     *
     * @param instances   the instances to partition
     * @param partition   fold number assigned to each top-level instance
     * @param num_folds   the number of folds
     * @param hierarchy   hierarchy label of each instance
     * @param identifiers identifier of each instance, indexed like {@code instances}
     */
    public CrossValidator(Instances instances, int[] partition, int num_folds, String[] hierarchy, String[] identifiers) {
        this.folds = this.partitionInstances(instances, partition, num_folds, identifiers, hierarchy);
    }

    /**
     * Trains and tests the given model once per fold and reports the results.
     *
     * <p>For every fold the model is trained on the fold's training set, the testing set
     * is classified, and the error rate and confusion matrix are recorded. The average
     * error rate, its standard deviation, the elapsed time in minutes, the summed
     * confusion matrix and the model itself are stored in {@code cvres[i]}.
     *
     * <p>Per-fold details, the fold count, the confusion matrix and the dimensionality
     * reduction line are only included when {@code cvres} has exactly one element.
     *
     * @param trained                 the model to train and evaluate
     * @param cvres                   result holders; only element {@code i} is written
     * @param instances               the full data set (used for its header and class names)
     * @param out                     stream that receives a progress banner
     * @param cv_results              buffer that receives the summary lines
     * @param file_name               file to write the report to, or {@code null} to skip
     * @param feature_selector        label reported as the dimensionality reduction method
     *                                when the model has an attribute selector
     * @param save_intermediate_arffs forwarded to the instance classifier
     * @param verbose                 whether to list each instance's classification per fold
     * @param i                       index into {@code cvres} of the entry to fill
     * @return the per-fold section followed by the summary section
     * @throws Exception if training, classification or writing fails
     */
    public String crossValidate(TrainedModel trained, CrossValidationResults[] cvres, Instances instances, OutputStream out, StringBuffer cv_results, String file_name, String feature_selector, boolean save_intermediate_arffs, boolean verbose, int i) throws Exception {
        out.write("\tCross Validating....\n".getBytes());
        DecimalFormat df = new DecimalFormat("####0.0#");
        long cross_val_start_time = System.currentTimeMillis();
        int num_folds = this.folds.length;
        double[] error_rates = new double[num_folds];
        double[][][] confusion_matrices = new double[num_folds][][];
        String[] classes = CrossValidator.getClassNames(instances);
        boolean single_result = cvres.length == 1;

        StringBuffer per_fold = new StringBuffer();
        per_fold.append("\n==================================================\n");
        per_fold.append("\n----------RESULTS PER FOLD----------\n");
        for (int fold = 0; fold < num_folds; ++fold) {
            // Empty data sets that share the header of the full set.
            Instances training = new Instances(instances, 100);
            Instances testing = new Instances(instances, 100);
            String[][] identifiers = this.getPartitionedInstances(training, testing, fold);
            Trainer.train(training, trained);
            Instances classified = InstanceClassifier.classifyInstances(trained, testing, save_intermediate_arffs);
            double correct_count = InstanceClassifier.getCorrectCount(testing, classified);
            double total_count = testing.numInstances();
            double success_rate = 100.0 * correct_count / total_count;
            error_rates[fold] = 100.0 - success_rate;
            confusion_matrices[fold] = InstanceClassifier.getConfusionMatrix(testing, classified, classes);
            if (single_result) {
                per_fold.append("\n*************** FOLD: " + fold + " ***************");
                per_fold.append("\nSuccess rate: " + df.format(success_rate));
                per_fold.append("\nConfusion matrix: \n" + InstanceClassifier.formatConfusionMatrix(confusion_matrices[fold], classes));
                if (verbose) {
                    per_fold.append(CrossValidator.getClassifications(testing, classified, training, identifiers) + "\n");
                }
            }
        }
        long cross_val_end_time = System.currentTimeMillis();

        CrossValidationResults result = cvres[i];
        result.error_rates = MathAndStatsMethods.getAverage(error_rates);
        // Milliseconds to minutes.
        result.cross_val_times = (double)(cross_val_end_time - cross_val_start_time) / 60000.0;
        result.standard_deviation = MathAndStatsMethods.getStandardDeviation(error_rates);
        double[][] overall_confusion_matrix = this.getOverallConfusionMatrix(confusion_matrices);
        result.cross_validation_confusion_matrices = InstanceClassifier.formatConfusionMatrix(overall_confusion_matrix, classes);
        result.trained = trained;

        // Lines shared by the returned report and the caller's cv_results buffer.
        StringBuffer both = new StringBuffer();
        both.append("\nAVERAGE SUCCESS RATE: " + df.format(100.0 - result.error_rates) + "%");
        both.append("\nAVERAGE ERROR RATE: " + df.format(result.error_rates) + "%\n");
        both.append("STANDARD DEVIATION: " + df.format(result.standard_deviation) + "\n");
        both.append("TOTAL CROSS-VALIDATION TIME: " + result.cross_val_times + " minutes\n");

        cv_results.append("\n==================================================\n");
        cv_results.append(both);
        cv_results.append("CLASSIFIER TYPE: " + result.classifier_descriptions + "\n");

        StringBuffer to_return = new StringBuffer();
        to_return.append("\n==================================================\n");
        to_return.append("\n----------AVERAGE RESULTS----------\n");
        to_return.append(both);
        if (single_result) {
            String dr = trained.attribute_selector != null ? feature_selector : "None";
            to_return.append("\nNUMBER OF FOLDS: " + num_folds + "\n");
            to_return.append("CONFUSION MATRIX:\n" + result.cross_validation_confusion_matrices + "\n");
            to_return.append("DIMENSIONALITY REDUCTION: " + dr + "\n");
        }
        to_return.append("CLASSIFIER TYPE: " + result.classifier_descriptions + "\n");

        // The returned text lists the per-fold section first; the saved file lists it last.
        String report = per_fold.toString() + to_return.toString();
        if (file_name != null) {
            // try-with-resources so the file handle is released even if writing fails.
            try (DataOutputStream writer = new DataOutputStream(new FileOutputStream(new File(file_name)))) {
                writer.writeBytes(to_return.toString() + per_fold.toString());
            }
        }
        return report;
    }

    /**
     * Randomly assigns each of {@code num_instances} entries to one of {@code num_folds} folds.
     *
     * <p>A drawn fold is rejected and redrawn while it already holds
     * {@code ceil(num_instances / num_folds)} entries. The final entry is exempt from this
     * check and is placed in whichever fold is drawn first.
     *
     * @param num_folds     the number of folds
     * @param num_instances the number of entries to assign
     * @return an array of length {@code num_instances} holding the fold drawn for each entry
     */
    public static int[] generatePartitionArray(int num_folds, int num_instances) {
        int[] partition_array = new int[num_instances];
        int[] count = new int[num_folds];
        double max = Math.ceil((double)num_instances / (double)num_folds);
        for (int i = 0; i < num_instances; ++i) {
            boolean is_last = i >= num_instances - 1;
            int selected = MathAndStatsMethods.generateRandomNumber(num_folds);
            while (!is_last && (double)count[selected] >= max) {
                selected = MathAndStatsMethods.generateRandomNumber(num_folds);
            }
            // Only accepted draws count towards a fold's size.
            count[selected]++;
            partition_array[i] = selected;
        }
        return partition_array;
    }

    /**
     * Sums the per-fold confusion matrices cell by cell.
     *
     * <p>The number of classes is taken from the first fold, so a single fold or a
     * single class is handled. All matrices are expected to be square and of equal size.
     *
     * @param confusion_matrices one confusion matrix per fold, indexed {@code [fold][row][column]};
     *                           must contain at least one fold
     * @return a matrix in which every cell is the sum of that cell over all folds
     */
    public double[][] getOverallConfusionMatrix(double[][][] confusion_matrices) {
        int num_folds = confusion_matrices.length;
        int num_classes = confusion_matrices[0].length;
        double[][] overall_matrix = new double[num_classes][num_classes];
        for (int a = 0; a < num_classes; ++a) {
            for (int p = 0; p < num_classes; ++p) {
                // Accumulate as double so fractional cell values are not truncated.
                double sum = 0.0;
                for (int f = 0; f < num_folds; ++f) {
                    sum += confusion_matrices[f][a][p];
                }
                overall_matrix[a][p] = sum;
            }
        }
        return overall_matrix;
    }

    /**
     * Returns the values of the class attribute of the given data set.
     *
     * @param instances the data set whose class attribute is read
     * @return the class attribute's values, in attribute order
     */
    public static String[] getClassNames(Instances instances) {
        int num_classes = instances.numClasses();
        String[] classes = new String[num_classes];
        for (int c = 0; c < num_classes; ++c) {
            classes[c] = instances.classAttribute().value(c);
        }
        return classes;
    }

    /**
     * Lists the training instances with their model class and the testing instances with
     * their actual and predicted class.
     *
     * <p>Entries whose identifier is {@code null} are skipped. Each section header is
     * written before the first listed entry of that section. Misclassified testing
     * instances are prefixed with an asterisk.
     *
     * @param actual      the testing instances with their true classes
     * @param predicted   the testing instances as classified, indexed like {@code actual}
     * @param training    the training instances
     * @param identifiers {@code identifiers[0]} names the training instances and
     *                    {@code identifiers[1]} the testing instances
     * @return the formatted listing
     */
    public static StringBuffer getClassifications(Instances actual, Instances predicted, Instances training, String[][] identifiers) {
        StringBuffer classes = new StringBuffer();

        boolean training_header_written = false;
        for (int train = 0; train < identifiers[0].length; ++train) {
            if (identifiers[0][train] == null) {
                continue;
            }
            if (!training_header_written) {
                classes.append("\n----------Training Instances----------\n");
                training_header_written = true;
            }
            classes.append("\n\nTraining Instance " + train + ": " + identifiers[0][train]);
            classes.append("\n\tModel Class: " + training.instance(train).stringValue(training.classIndex()));
        }

        boolean testing_header_written = false;
        for (int inst = 0; inst < actual.numInstances(); ++inst) {
            if (identifiers[1][inst] == null) {
                continue;
            }
            if (!testing_header_written) {
                classes.append("\n\n----------Testing Instances----------\n");
                testing_header_written = true;
            }
            String actual_class = actual.instance(inst).stringValue(actual.classIndex());
            String predicted_class = predicted.instance(inst).stringValue(predicted.classIndex());
            classes.append("\n");
            if (!actual_class.equals(predicted_class)) {
                classes.append("*");
            }
            classes.append("Testing Instance " + inst + ": " + identifiers[1][inst]);
            classes.append("\n\tActual Class: " + actual_class);
            classes.append("\tPredicted Class: " + predicted_class + "\n");
        }
        return classes;
    }

    /**
     * Finds the positions of the top-level entries in a hierarchy, i.e. those whose label
     * contains no underscore.
     *
     * <p>If fewer than {@code num_overall} such entries exist the remaining elements stay 0;
     * if more exist an {@link ArrayIndexOutOfBoundsException} is thrown.
     *
     * @param hierarchy   hierarchy label of each instance
     * @param num_overall the expected number of top-level entries
     * @return an array of length {@code num_overall} with the positions found, in order
     */
    public static int[] getIndecesOfOverallInstances(String[] hierarchy, int num_overall) {
        int[] indeces = new int[num_overall];
        int next = 0;
        for (int i = 0; i < hierarchy.length; ++i) {
            if (!hierarchy[i].contains("_")) {
                indeces[next++] = i;
            }
        }
        return indeces;
    }

    /**
     * Finds the positions of all hierarchy entries that start with {@code overall}
     * followed by an underscore.
     *
     * @param hierarchy hierarchy label of each instance
     * @param overall   label of the top-level entry whose sub-sections are wanted
     * @return the matching positions in ascending order; empty if there are none
     */
    public static Integer[] getIndecesOfSubsections(String[] hierarchy, String overall) {
        String prefix = overall + "_";
        LinkedList<Integer> index_list = new LinkedList<Integer>();
        for (int i = 0; i < hierarchy.length; ++i) {
            if (hierarchy[i].startsWith(prefix)) {
                index_list.add(i);
            }
        }
        return index_list.toArray(new Integer[0]);
    }

    /**
     * Builds the per-fold training/testing slot arrays and fills {@link #names} in parallel.
     *
     * <p>For each fold, a top-level entry assigned to that fold goes into the testing
     * slots together with its sub-sections (placed in the slots directly after it); every
     * other entry goes into the training slots.
     *
     * @param instances2  the data set to partition; it is copied before use
     * @param partition   fold number of each top-level entry
     * @param num_folds   the number of folds
     * @param identifiers identifier of each instance
     * @param hierarchy   hierarchy label of each instance
     * @return the slot arrays, indexed {@code [fold][0 = training, 1 = testing][slot]}
     */
    private Instance[][][] partitionInstances(Instances instances2, int[] partition, int num_folds, String[] identifiers, String[] hierarchy) {
        Instances instances = new Instances(instances2);
        int num_slots = instances.numInstances();
        Instance[][][] partitioned = new Instance[num_folds][2][num_slots];
        this.names = new String[num_folds][2][num_slots];
        int[] overall = CrossValidator.getIndecesOfOverallInstances(hierarchy, partition.length);
        for (int fold = 0; fold < num_folds; ++fold) {
            for (int inst = 0; inst < partition.length; ++inst) {
                if (partition[inst] != fold) {
                    // NOTE(review): the training branch indexes with inst, whereas the testing
                    // branch uses overall[inst]. The two only agree when there are no
                    // sub-sections -- confirm which is intended.
                    partitioned[fold][0][inst] = instances.instance(inst);
                    this.names[fold][0][inst] = identifiers[inst];
                    partitioned[fold][1][inst] = null;
                    continue;
                }
                int source = overall[inst];
                partitioned[fold][1][inst] = instances.instance(source);
                this.names[fold][1][inst] = identifiers[source];
                partitioned[fold][0][inst] = null;
                Integer[] subs = CrossValidator.getIndecesOfSubsections(hierarchy, hierarchy[source]);
                // NOTE(review): sub-section slots (inst + 1, inst + 2, ...) may be overwritten
                // when the following top-level entries are processed -- verify.
                for (int i = 0; i < subs.length; ++i) {
                    int slot = inst + i + 1;
                    partitioned[fold][1][slot] = instances.instance(subs[i].intValue());
                    this.names[fold][1][slot] = identifiers[subs[i]];
                    partitioned[fold][0][slot] = null;
                }
            }
        }
        return partitioned;
    }

    /**
     * Adds the instances of one fold to the given training and testing sets.
     *
     * <p>A slot's training instance takes precedence: its testing instance is only used
     * when the training slot is {@code null}.
     *
     * @param training receives the fold's training instances
     * @param testing  receives the fold's testing instances
     * @param fold     the fold to read
     * @return the identifiers of the added instances, packed from index 0:
     *         element 0 for training, element 1 for testing; unused tail entries are {@code null}
     */
    private String[][] getPartitionedInstances(Instances training, Instances testing, int fold) {
        Instance[] train_slots = this.folds[fold][0];
        Instance[] test_slots = this.folds[fold][1];
        String[][] name_array = new String[2][train_slots.length];
        int train_count = 0;
        int test_count = 0;
        for (int slot = 0; slot < train_slots.length; ++slot) {
            if (train_slots[slot] != null) {
                training.add(train_slots[slot]);
                name_array[0][train_count++] = this.names[fold][0][slot];
            } else if (test_slots[slot] != null) {
                testing.add(test_slots[slot]);
                name_array[1][test_count++] = this.names[fold][1][slot];
            }
        }
        return name_array;
    }

    /**
     * Debugging aid that prints every identifier next to its instance on standard output.
     *
     * @param identifiers identifier of each instance
     * @param instances   the instances to print
     * @return always the empty string
     */
    private static String testInstanceIdentification(String[] identifiers, Instances instances) {
        String compare = "";
        for (int i = 0; i < instances.numInstances(); ++i) {
            System.out.println("\nidentifier: " + identifiers[i] + "\tfeatures: " + instances.instance(i).toString());
        }
        return compare;
    }
}

