/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.nn.alg;

import java.util.List;
import java.util.Objects;
import org.ddogleg.nn.alg.AxisSplitRule;
import org.ddogleg.nn.alg.AxisSplitRuleMax;
import org.ddogleg.nn.alg.AxisSplitter;
import org.ddogleg.nn.alg.KdTreeDistance;
import org.ddogleg.sorting.QuickSelect;
import org.ddogleg.struct.DogArray_I32;
import org.jetbrains.annotations.Nullable;

/**
 * {@link AxisSplitter} that cuts a set of points at the median along a single axis.
 *
 * <p>For every call to {@link #splitData}, the per-axis spread of the points is measured and handed to an
 * {@link AxisSplitRule}, which picks the axis to split on. The point occupying position {@code size/2} along
 * that axis (found with {@link QuickSelect}) becomes the split point. Points selected before it are sent
 * to the "left" output and points selected after it to the "right" output. The split point itself is sent to
 * neither side.</p>
 *
 * <p>Scratch arrays and the results of the last split are stored in instance fields and reused between
 * calls. A single instance should therefore not be shared across threads without external synchronization.</p>
 *
 * @param <P> point type whose coordinates are read through {@link KdTreeDistance#valueAt}
 */
public class AxisSplitterMedian<P> implements AxisSplitter<P> {
    // Number of coordinates per point, taken from distance.length().
    private final int N;
    // Per-axis mean computed during the most recent split.
    private final double[] mean;
    // Per-axis sum of squared deviations from the mean (NOT divided by the point count).
    private final double[] var;
    // Scratch buffer holding each point's coordinate along the split axis.
    // It only grows, never shrinks.
    private double[] tmp = new double[1];
    // Point positions reordered by QuickSelect; always the same length as tmp.
    private int[] indexes = new int[1];

    AxisSplitRule splitRule;
    KdTreeDistance<P> distance;

    // Results of the most recent splitData() call.
    int splitAxis;
    P splitPoint;
    int splitIndex;

    /**
     * Creates a splitter that uses the supplied rule to choose the split axis.
     *
     * @param distance  accessor for point coordinates; its {@code length()} fixes the dimensionality
     * @param splitRule rule that picks an axis from the per-axis spread; it is told the dimensionality here
     */
    public AxisSplitterMedian(KdTreeDistance<P> distance, AxisSplitRule splitRule) {
        this.distance = distance;
        this.splitRule = splitRule;
        final int dimension = distance.length();
        this.N = dimension;
        this.mean = new double[dimension];
        this.var = new double[dimension];
        splitRule.setDimension(dimension);
    }

    /**
     * Creates a splitter that chooses the split axis with an {@link AxisSplitRuleMax}.
     *
     * @param distance accessor for point coordinates
     */
    public AxisSplitterMedian(KdTreeDistance<P> distance) {
        this(distance, new AxisSplitRuleMax());
    }

    /**
     * Splits {@code points} at the median along the axis chosen by the split rule.
     *
     * <p>The split point is stored (see {@link #getSplitPoint()}) and is added to neither output list.
     * {@code left} and {@code right} are appended to; they are not cleared first.</p>
     *
     * <p>Which side a point ends up on is decided by the order {@link QuickSelect#selectIndex} leaves the
     * positions in. Presumably, entries before the median are not greater than it along the split axis;
     * confirm this against the QuickSelect contract.</p>
     *
     * <p>The list is presumably required to be non-empty: with no points, the median lookup has nothing to
     * return.</p>
     *
     * @param points       points to split
     * @param indexes      optional per-point labels, parallel to {@code points}. When {@code null}, only the
     *                     point lists are filled and {@link #getSplitIndex()} keeps its previous value.
     * @param left         receives the points selected before the median
     * @param leftIndexes  when {@code indexes} is non-null, it must be non-null. It is reset, then filled with
     *                     the labels of the points added to {@code left}.
     * @param right        receives the points selected after the median
     * @param rightIndexes when {@code indexes} is non-null, it must be non-null. It is reset, then filled with
     *                     the labels of the points added to {@code right}.
     * @throws RuntimeException     if any axis produced a NaN spread (for example, NaN or infinite coordinates)
     * @throws NullPointerException if {@code indexes} is non-null but either output index array is null
     */
    @Override
    public void splitData(List<P> points, @Nullable DogArray_I32 indexes, List<P> left, @Nullable DogArray_I32 leftIndexes, List<P> right, @Nullable DogArray_I32 rightIndexes) {
        computeAxisVariance(points);
        rejectNaNVariance(points);

        splitAxis = splitRule.select(var);

        final int total = points.size();
        final int median = total / 2;
        quickSelect(points, splitAxis, median);

        // Position in `points` (and in `indexes`) of the element chosen as the median.
        final int medianSource = this.indexes[median];
        splitPoint = points.get(medianSource);

        if (indexes != null) {
            Objects.requireNonNull(leftIndexes);
            Objects.requireNonNull(rightIndexes);
            leftIndexes.reset();
            rightIndexes.reset();
        }

        // Slot `median` is skipped: the median point is the split point, not a child.
        copySelected(points, indexes, 0, median, left, leftIndexes);
        copySelected(points, indexes, median + 1, total, right, rightIndexes);

        if (indexes != null) {
            splitIndex = indexes.get(medianSource);
        }
    }

    /**
     * Throws if any axis ended up with a NaN spread, which can only come from bad input coordinates.
     */
    private void rejectNaNVariance(List<P> points) {
        for (int axis = 0; axis < N; axis++) {
            if (Double.isNaN(var[axis])) {
                throw new RuntimeException("Variance is NaN.  Bad input is the cause. mean[i]=" + mean[axis] + " i=" + axis + " points.size=" + points.size());
            }
        }
    }

    /**
     * Appends the points at selected slots [{@code from}, {@code to}) to {@code destination}. When
     * {@code sourceIndexes} is non-null, it also appends their labels to {@code destinationIndexes}.
     * The caller guarantees {@code destinationIndexes} is non-null whenever {@code sourceIndexes} is.
     */
    private void copySelected(List<P> points, @Nullable DogArray_I32 sourceIndexes, int from, int to, List<P> destination, @Nullable DogArray_I32 destinationIndexes) {
        for (int slot = from; slot < to; slot++) {
            final int source = this.indexes[slot];
            destination.add(points.get(source));
            if (sourceIndexes != null) {
                destinationIndexes.add(sourceIndexes.get(source));
            }
        }
    }

    /**
     * Returns the point chosen as the median in the most recent split.
     */
    @Override
    public P getSplitPoint() {
        return splitPoint;
    }

    /**
     * Returns the label of the split point from the most recent split. The value is only updated when that
     * split was given a non-null {@code indexes} array.
     */
    @Override
    public int getSplitIndex() {
        return splitIndex;
    }

    /**
     * Returns the axis selected by the split rule in the most recent split.
     */
    @Override
    public int getSplitAxis() {
        return splitAxis;
    }

    /**
     * Returns the number of coordinates per point, as reported by {@code distance.length()}.
     */
    @Override
    public int getPointLength() {
        return N;
    }

    /**
     * Fills {@link #mean} with the per-axis mean and {@link #var} with the per-axis sum of squared deviations.
     * This is a two-pass computation. The sum is left unnormalized; it only feeds the split rule.
     */
    private void computeAxisVariance(List<P> points) {
        final int count = points.size();
        for (int axis = 0; axis < N; axis++) {
            mean[axis] = 0.0;
            var[axis] = 0.0;
        }

        // Pass 1: accumulate per-axis sums, then turn them into means.
        for (P point : points) {
            for (int axis = 0; axis < N; axis++) {
                mean[axis] += distance.valueAt(point, axis);
            }
        }
        for (int axis = 0; axis < N; axis++) {
            mean[axis] /= count;
        }

        // Pass 2: accumulate squared deviations from the mean.
        for (P point : points) {
            for (int axis = 0; axis < N; axis++) {
                final double delta = mean[axis] - distance.valueAt(point, axis);
                var[axis] += delta * delta;
            }
        }
    }

    /**
     * Copies each point's coordinate along {@code axis} into {@link #tmp}, then lets QuickSelect reorder
     * {@link #indexes} around the {@code k}-th element. The scratch arrays are grown together when needed.
     */
    private void quickSelect(List<P> points, int axis, int k) {
        final int count = points.size();
        if (count > tmp.length) {
            tmp = new double[count];
            indexes = new int[count];
        }
        for (int slot = 0; slot < count; slot++) {
            tmp[slot] = distance.valueAt(points.get(slot), axis);
        }
        QuickSelect.selectIndex(tmp, k, count, indexes);
    }
}

