/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.evaluation;

import java.util.List;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMetrics;
import org.tribuo.classification.evaluation.LabelEvaluationUtil;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.evaluation.metrics.MetricTarget;

/**
 * Enumerates the label metrics available in this package, pairing each name with
 * the function that evaluates it for a {@link MetricTarget} and a
 * {@link LabelMetric.Context}.
 * <p>
 * The confusion-based constants delegate to {@link ConfusionMetrics} using the value
 * returned by the context's {@code getCM()}; the score-based constants
 * ({@link #AUCROC}, {@link #AVERAGED_PRECISION}) work from the context's predictions.
 */
public enum LabelMetrics {
    /** Delegates to {@code ConfusionMetrics.tp}. */
    TP((target, context) -> ConfusionMetrics.tp(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.fp}. */
    FP((target, context) -> ConfusionMetrics.fp(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.tn}. */
    TN((target, context) -> ConfusionMetrics.tn(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.fn}. */
    FN((target, context) -> ConfusionMetrics.fn(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.precision}. */
    PRECISION((target, context) -> ConfusionMetrics.precision(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.recall}. */
    RECALL((target, context) -> ConfusionMetrics.recall(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.f1}. */
    F1((target, context) -> ConfusionMetrics.f1(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.accuracy}. */
    ACCURACY((target, context) -> ConfusionMetrics.accuracy(target, context.getCM())),
    /** Delegates to {@code ConfusionMetrics.balancedErrorRate}; the metric target is not consulted. */
    BALANCED_ERROR_RATE((target, context) -> ConfusionMetrics.balancedErrorRate(context.getCM())),
    /** Delegates to {@link #AUCROC(MetricTarget, List)} with the context's predictions. */
    AUCROC((target, context) -> LabelMetrics.AUCROC(target, context.getPredictions())),
    /** Delegates to {@link #averagedPrecision(MetricTarget, List)} with the context's predictions. */
    AVERAGED_PRECISION((target, context) -> LabelMetrics.averagedPrecision(target, context.getPredictions()));

    /** The function that evaluates this metric. */
    private final ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl;

    /**
     * Binds an enum constant to its evaluation function.
     *
     * @param impl the function invoked to compute the metric value
     */
    private LabelMetrics(ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl) {
        this.impl = impl;
    }

    /**
     * Returns the evaluation function bound to this constant.
     *
     * @return the function supplied when the constant was constructed
     */
    public ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> getImpl() {
        return impl;
    }

    /**
     * Creates a {@link LabelMetric} for the given target, named after this constant
     * and backed by this constant's evaluation function.
     *
     * @param tgt the target the new metric is built for
     * @return a new {@link LabelMetric}
     */
    public LabelMetric forTarget(MetricTarget<Label> tgt) {
        return new LabelMetric(tgt, name(), impl);
    }

    /**
     * Computes the area under the ROC curve for the label carried by a metric target.
     *
     * @param tgt         the metric target; must have an output target present
     * @param predictions the predictions to score
     * @return the value produced by {@link #AUCROC(Label, List)} for the target's label
     * @throws IllegalStateException         if {@code tgt} has no output target
     * @throws UnsupportedOperationException if a prediction has no probabilities
     */
    public static double AUCROC(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        Label targetLabel = tgt.getOutputTarget()
                .orElseThrow(() -> new IllegalStateException("Unsupported MetricTarget for AUCROC"));
        return LabelMetrics.AUCROC(targetLabel, predictions);
    }

    /**
     * Computes the area under the ROC curve for a single label by handing the
     * per-prediction flags and scores to {@code LabelEvaluationUtil.binaryAUCROC}.
     *
     * @param label       the label treated as the positive class
     * @param predictions the predictions to score
     * @return the value returned by {@code LabelEvaluationUtil.binaryAUCROC}
     * @throws UnsupportedOperationException if a prediction has no probabilities
     */
    public static double AUCROC(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.binaryAUCROC(probs.positives, probs.scores);
    }

    /**
     * Computes the averaged precision for the label carried by a metric target.
     *
     * @param tgt         the metric target; must have an output target present
     * @param predictions the predictions to score
     * @return the value produced by {@link #averagedPrecision(Label, List)} for the target's label
     * @throws IllegalStateException         if {@code tgt} has no output target
     * @throws UnsupportedOperationException if a prediction has no probabilities
     */
    public static double averagedPrecision(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        Label targetLabel = tgt.getOutputTarget()
                .orElseThrow(() -> new IllegalStateException("Unsupported MetricTarget for averagedPrecision"));
        return LabelMetrics.averagedPrecision(targetLabel, predictions);
    }

    /**
     * Computes the averaged precision for a single label by handing the
     * per-prediction flags and scores to {@code LabelEvaluationUtil.averagedPrecision}.
     *
     * @param label       the label treated as the positive class
     * @param predictions the predictions to score
     * @return the value returned by {@code LabelEvaluationUtil.averagedPrecision}
     * @throws UnsupportedOperationException if a prediction has no probabilities
     */
    public static double averagedPrecision(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.averagedPrecision(probs.positives, probs.scores);
    }

    /**
     * Builds the precision-recall curve for a single label by handing the
     * per-prediction flags and scores to {@code LabelEvaluationUtil.generatePRCurve}.
     *
     * @param label       the label treated as the positive class
     * @param predictions the predictions to score
     * @return the curve returned by {@code LabelEvaluationUtil.generatePRCurve}
     * @throws UnsupportedOperationException if a prediction has no probabilities
     */
    public static LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.generatePRCurve(probs.positives, probs.scores);
    }

    /**
     * Parallel arrays extracted from a list of predictions for one label: whether
     * each prediction's example output equals the label, and the score each
     * prediction stores under that label's name.
     */
    private static final class PredictionProbabilities {
        /** {@code positives[i]} is true when the i-th example's output equals the label. */
        final boolean[] positives;
        /** {@code scores[i]} is the i-th prediction's score for the label. */
        final double[] scores;

        /**
         * Walks the predictions in list order and fills both arrays.
         *
         * @param label       the label to extract flags and scores for
         * @param predictions the predictions to read
         * @throws UnsupportedOperationException if any prediction reports no probabilities
         */
        PredictionProbabilities(Label label, List<Prediction<Label>> predictions) {
            final int count = predictions.size();
            this.positives = new boolean[count];
            this.scores = new double[count];
            for (int idx = 0; idx < count; idx++) {
                Prediction<Label> current = predictions.get(idx);
                if (!current.hasProbabilities()) {
                    throw new UnsupportedOperationException(String.format("Invalid prediction at index %d: has no probability score.", idx));
                }
                // A freshly allocated boolean[] is all false, so a direct assignment
                // gives the same result as setting only the matching entries.
                this.positives[idx] = current.getExample().getOutput().equals(label);
                this.scores[idx] = current.getOutputScores().get(label.getLabel()).getScore();
            }
        }
    }
}

