/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.DeepBoofConstants;
import deepboof.forward.FunctionBatchNorm;
import deepboof.impl.forward.standard.BaseFunction;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Forward-only batch normalization for {@link Tensor_F64} that uses precomputed statistics.
 *
 * <p>A single parameter tensor holds one group of consecutive values for every element of a
 * sample. Each group is {@code [mean, variance]}, or {@code [mean, variance, gamma, beta]} when
 * {@link #hasGammaBeta()} is true. When parameters are set, the internal copy has every variance
 * replaced by {@code 1 / sqrt(variance + EPS)}. As a result, {@link #_forward} computes
 * {@code (x - mean) * invStd} or {@code (x - mean) * (gamma * invStd) + beta} for each element.
 *
 * <p>NOTE: {@code EPS} is folded into the stored parameters by {@link #_setParameters}. Calling
 * {@link #setEPS(double)} afterwards therefore has no effect until the parameters are set again.
 */
public class FunctionBatchNorm_F64
extends BaseFunction<Tensor_F64>
implements FunctionBatchNorm<Tensor_F64> {
    /** If true, every parameter group also carries gamma (scale) and beta (offset). */
    protected boolean requiresGammaBeta;
    /** Internal copy of the parameters, with variances replaced by inverse standard deviations. */
    protected Tensor_F64 params = new Tensor_F64(0);
    /** Small constant added to each variance before taking the square root. */
    protected double EPS = DeepBoofConstants.TEST_TOL_F64 * 0.1;

    /**
     * @param requiresGammaBeta true if the parameters include gamma and beta per element,
     *                          false for mean and variance only
     */
    public FunctionBatchNorm_F64(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /** Number of consecutive parameter values stored for each input element. */
    private int paramsPerElement() {
        return requiresGammaBeta ? 4 : 2;
    }

    /**
     * Sets the output shape equal to the input shape and declares the parameter shape.
     * The parameter shape is the input shape extended by the per-element group size (2 or 4).
     */
    @Override
    public void _initialize() {
        // array clone() already returns int[], so no cast is needed
        shapeOutput = shapeInput.clone();
        // presumably TensorOps.WI appends a trailing dimension of the given size - TODO confirm
        int[] shapeParam = TensorOps.WI(shapeInput, paramsPerElement());
        shapeParameters.add(shapeParam);
        params.reshape(shapeParam);
    }

    /**
     * Copies the first tensor in {@code parameters} into the internal parameter tensor.
     * It then converts the second value of every group (the variance) into
     * {@code 1 / sqrt(variance + EPS)}.
     *
     * @param parameters list whose first element is the packed parameter tensor
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        params.setTo(parameters.get(0));
        final int stride = paramsPerElement();
        // honor startIndex so this agrees with how _forward reads the same tensor
        final int start = params.startIndex;
        final int end = start + params.length();
        for (int i = start + 1; i < end; i += stride) {
            params.d[i] = 1.0 / Math.sqrt(params.d[i] + EPS);
        }
    }

    /**
     * Normalizes every sample in the mini-batch with the stored per-element parameters.
     *
     * @param input  tensor whose first dimension is the mini-batch
     * @param output tensor that receives the normalized values
     * @throws IllegalArgumentException if {@code input} has fewer than two dimensions
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (input.getDimension() <= 1) {
            throw new IllegalArgumentException("Input tensor must be at least 2D.  First dimension of batch.");
        }
        // presumably the element count of one sample (all dimensions after the first)
        final int D2 = TensorOps.outerLength(input.shape, 1);
        final double[] in = input.d;
        final double[] out = output.d;
        final double[] p = params.d;
        int indexIn = input.startIndex;
        int indexOut = output.startIndex;
        for (int batch = 0; batch < miniBatchSize; ++batch) {
            // the same parameter groups are applied to every sample in the mini-batch
            int indexP = params.startIndex;
            final int end = indexIn + D2;
            if (requiresGammaBeta) {
                while (indexIn < end) {
                    double mean = p[indexP++];
                    double inv_stdev_eps = p[indexP++];
                    double gamma = p[indexP++];
                    double beta = p[indexP++];
                    out[indexOut++] = (in[indexIn++] - mean) * (gamma * inv_stdev_eps) + beta;
                }
            } else {
                while (indexIn < end) {
                    double mean = p[indexP++];
                    double inv_stdev_eps = p[indexP++];
                    out[indexOut++] = (in[indexIn++] - mean) * inv_stdev_eps;
                }
            }
        }
    }

    /** Returns the constant added to the variance before the square root. */
    @Override
    public double getEPS() {
        return this.EPS;
    }

    /**
     * Sets the constant added to the variance before the square root. It only takes effect
     * the next time the parameters are set.
     */
    @Override
    public void setEPS(double EPS) {
        this.EPS = EPS;
    }

    /** Returns true if each parameter group includes gamma and beta. */
    @Override
    public boolean hasGammaBeta() {
        return this.requiresGammaBeta;
    }

    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}

