/*
 * Decompiled with CFR 0.152.
 */
package deepboof.misc;

import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F32;
import java.util.Arrays;

public class TensorOps_F32 {
    /**
     * Multiplies every element of {@code tensor} by {@code value}, in place.
     *
     * @param tensor tensor whose {@code length()} elements starting at {@code startIndex} are scaled
     * @param value scale factor
     */
    public static void elementMult(Tensor_F32 tensor, float value) {
        final int end = tensor.startIndex + tensor.length();
        for (int i = tensor.startIndex; i < end; ++i) {
            tensor.d[i] *= value;
        }
    }

    /**
     * Writes {@code input * value} element-wise into {@code output}.
     *
     * <p>{@link TensorOps#checkShape} is called on the two tensors before anything is written.
     *
     * @param input source tensor
     * @param value scale factor
     * @param output destination tensor
     */
    public static void elementMult(Tensor_F32 input, float value, Tensor_F32 output) {
        TensorOps.checkShape(input, output);
        final int end = input.startIndex + input.length();
        for (int indexIn = input.startIndex, indexOut = output.startIndex; indexIn < end; ++indexIn, ++indexOut) {
            output.d[indexOut] = input.d[indexIn] * value;
        }
    }

    /**
     * Element-wise product {@code output = A2 * B}.
     *
     * <p>Iterates over {@code A2.length()} elements, starting at each tensor's {@code startIndex}.
     * If {@code output} is the same object as {@code A2} or {@code B}, the product is written in
     * place into that tensor.
     *
     * <p>NOTE(review): no shape validation is performed here. The caller must ensure that
     * {@code B} and {@code output} hold at least {@code A2.length()} elements.
     *
     * @param A2 first operand
     * @param B second operand
     * @param output destination; may be {@code A2} or {@code B}
     */
    public static void elementMult(Tensor_F32 A2, Tensor_F32 B, Tensor_F32 output) {
        int indexA = A2.startIndex;
        final int endA = indexA + A2.length();
        int indexB = B.startIndex;
        if (A2 != output && B != output) {
            int indexOut = output.startIndex;
            while (indexA < endA) {
                output.d[indexOut++] = A2.d[indexA++] * B.d[indexB++];
            }
        } else if (B == output) {
            // In place into B. This branch also handles A2 == B == output, because each element
            // is read before it is overwritten.
            while (indexA < endA) {
                B.d[indexB++] *= A2.d[indexA++];
            }
        } else {
            // Only A2 == output remains.
            while (indexA < endA) {
                A2.d[indexA++] *= B.d[indexB++];
            }
        }
    }

    /**
     * Element-wise sum {@code output = A2 + B}.
     *
     * <p>Iterates over {@code A2.length()} elements, starting at each tensor's {@code startIndex}.
     * If {@code output} is the same object as {@code A2} or {@code B}, the sum is written in
     * place into that tensor.
     *
     * <p>NOTE(review): no shape validation is performed here.
     *
     * @param A2 first operand
     * @param B second operand
     * @param output destination; may be {@code A2} or {@code B}
     */
    public static void elementAdd(Tensor_F32 A2, Tensor_F32 B, Tensor_F32 output) {
        int indexA = A2.startIndex;
        final int endA = indexA + A2.length();
        int indexB = B.startIndex;
        if (A2 != output && B != output) {
            int indexOut = output.startIndex;
            while (indexA < endA) {
                output.d[indexOut++] = A2.d[indexA++] + B.d[indexB++];
            }
        } else if (B == output) {
            while (indexA < endA) {
                B.d[indexB++] += A2.d[indexA++];
            }
        } else {
            while (indexA < endA) {
                A2.d[indexA++] += B.d[indexB++];
            }
        }
    }

    /**
     * Sums the {@code length()} elements of {@code tensor}, starting at {@code startIndex}.
     *
     * @param tensor tensor to sum
     * @return the sum, accumulated in single precision
     */
    public static float elementSum(Tensor_F32 tensor) {
        final int end = tensor.startIndex + tensor.length();
        float sum = 0.0f;
        for (int i = tensor.startIndex; i < end; ++i) {
            sum += tensor.d[i];
        }
        return sum;
    }

    /**
     * Copies a {@code rows x columns} block between the raw arrays of two tensors.
     *
     * <p>Each row is {@code columns} contiguous elements. Consecutive rows are {@code srcStride} or
     * {@code dstStride} elements apart. All indices are absolute positions in the {@code d} arrays.
     *
     * @param src source tensor
     * @param srcStartIndex index in {@code src.d} of the first element of the first row
     * @param srcStride distance between the starts of consecutive rows in {@code src.d}
     * @param dst destination tensor
     * @param dstStartIndex index in {@code dst.d} of the first element of the first row
     * @param dstStride distance between the starts of consecutive rows in {@code dst.d}
     * @param rows number of rows to copy
     * @param columns number of elements per row
     */
    public static void insertSubChannel(Tensor_F32 src, int srcStartIndex, int srcStride, Tensor_F32 dst, int dstStartIndex, int dstStride, int rows, int columns) {
        int indexSrc = srcStartIndex;
        int indexDst = dstStartIndex;
        for (int i = 0; i < rows; ++i) {
            System.arraycopy(src.d, indexSrc, dst.d, indexDst, columns);
            indexSrc += srcStride;
            indexDst += dstStride;
        }
    }

    /**
     * Copies every channel of {@code src}'s trailing (channel, row, column) volume into {@code dst}.
     *
     * <p>The last three entries of {@code srcCoor} are overwritten with zero. The copy therefore
     * always starts at {@code src}'s first channel, row and column. In {@code dstCoor}, only the
     * channel entry (third from last) is overwritten with zero. Its last two entries are kept and
     * act as the (row, column) offset in {@code dst} where the copy is placed. Both coordinate
     * arrays are modified.
     *
     * <p>The index arithmetic assumes each tensor's last three axes form a contiguous row-major
     * block: row stride is the width and channel stride is width * height.
     *
     * @param src source tensor, with at least 3 dimensions
     * @param srcCoor coordinate into {@code src}; its length must equal {@code src.getDimension()}
     * @param dst destination tensor, with at least 3 dimensions and the same channel count as src
     * @param dstCoor coordinate into {@code dst}; its length must equal {@code dst.getDimension()}
     * @throws IllegalArgumentException if the dimensions do not match, or if {@code src} does not
     *     fit inside {@code dst} at the requested offset
     */
    public static void insertSpatial(Tensor_F32 src, int[] srcCoor, Tensor_F32 dst, int[] dstCoor) {
        if (srcCoor.length < 3) {
            throw new IllegalArgumentException("dimensions must be >= 3 for src");
        }
        if (dstCoor.length < 3) {
            throw new IllegalArgumentException("dimensions must be >= 3 for dst");
        }
        if (srcCoor.length != src.getDimension()) {
            throw new IllegalArgumentException("Coordinate length doesn't match tensor dimension for src");
        }
        if (dstCoor.length != dst.getDimension()) {
            throw new IllegalArgumentException("Coordinate length doesn't match tensor dimension for dst");
        }
        final int srcAxis = srcCoor.length - 3;
        final int dstAxis = dstCoor.length - 3;
        Arrays.fill(srcCoor, srcAxis, srcCoor.length, 0);
        dstCoor[dstAxis] = 0;

        // Negative axis arguments presumably count from the last axis -- TODO confirm against Tensor.length(int).
        final int numChannels = src.length(-3);
        final int height = src.length(-2);
        final int width = src.length(-1);
        final int heightDst = dst.length(-2);
        final int widthDst = dst.length(-1);
        final int offsetY = dstCoor[dstAxis + 1];
        final int offsetX = dstCoor[dstAxis + 2];

        if (numChannels != dst.length(dstAxis)) {
            throw new IllegalArgumentException("Number of channels do not match in src and dst");
        }
        if (height > heightDst) {
            throw new IllegalArgumentException("src height is larger than dst");
        }
        if (width > widthDst) {
            throw new IllegalArgumentException("src width is larger than dst");
        }
        // The dst offset is kept from the caller, so it must be validated too. Otherwise rows spill
        // into the next row or channel of dst, or past the end of the array.
        if (offsetY < 0 || offsetX < 0) {
            throw new IllegalArgumentException("dst spatial coordinate must be non-negative");
        }
        if (offsetY + height > heightDst) {
            throw new IllegalArgumentException("src does not fit inside dst at the requested row offset");
        }
        if (offsetX + width > widthDst) {
            throw new IllegalArgumentException("src does not fit inside dst at the requested column offset");
        }

        int pixelSrc = src.idx(srcCoor);
        int pixelDst = dst.idx(dstCoor);
        if (width == widthDst && height == heightDst) {
            // After the checks above, equal spatial sizes imply a zero offset. The whole volume is
            // therefore one contiguous run in both tensors.
            System.arraycopy(src.d, pixelSrc, dst.d, pixelDst, numChannels * width * height);
            return;
        }
        int channelSrc = pixelSrc;
        int channelDst = pixelDst;
        for (int channel = 0; channel < numChannels; ++channel) {
            pixelSrc = channelSrc;
            pixelDst = channelDst;
            for (int row = 0; row < height; ++row) {
                System.arraycopy(src.d, pixelSrc, dst.d, pixelDst, width);
                pixelSrc += width;
                pixelDst += widthDst;
            }
            channelSrc += width * height;
            channelDst += widthDst * heightDst;
        }
    }

    /**
     * Fills a border of the given thickness around every channel's (row, column) plane with
     * {@code value}.
     *
     * <p>The last three entries of {@code coor} are used as (channel, row, column) and are
     * overwritten; {@code coor} is modified. The leading entries select which sub-tensor is filled.
     * Rows within a channel are assumed to be contiguous, {@code width} elements long.
     *
     * @param tensor tensor to modify, with at least 3 dimensions
     * @param coor coordinate into {@code tensor}; its length must equal {@code tensor.getDimension()}
     * @param borderY0 number of rows to fill at the top
     * @param borderX0 number of columns to fill on the left
     * @param borderY1 number of rows to fill at the bottom
     * @param borderX1 number of columns to fill on the right
     * @param value fill value
     * @throws IllegalArgumentException if {@code coor} is invalid, a border is negative, or the
     *     borders exceed the plane size
     */
    public static void fillSpatialBorder(Tensor_F32 tensor, int[] coor, int borderY0, int borderX0, int borderY1, int borderX1, float value) {
        if (coor.length < 3) {
            throw new IllegalArgumentException("dimensions must be >= 3");
        }
        if (coor.length != tensor.getDimension()) {
            throw new IllegalArgumentException("Coordinate length doesn't match tensor dimension");
        }
        if (borderY0 < 0 || borderX0 < 0 || borderY1 < 0 || borderX1 < 0) {
            throw new IllegalArgumentException("Border sizes must be non-negative");
        }
        final int channelAxis = coor.length - 3;
        Arrays.fill(coor, channelAxis, coor.length, 0);

        final int numChannels = tensor.length(channelAxis);
        final int height = tensor.length(channelAxis + 1);
        final int width = tensor.length(channelAxis + 2);
        if (borderY0 + borderY1 > height) {
            throw new IllegalArgumentException("Y border is larger than image height");
        }
        if (borderX0 + borderX1 > width) {
            throw new IllegalArgumentException("X border is larger than image width");
        }
        for (int channel = 0; channel < numChannels; ++channel) {
            coor[channelAxis] = channel;
            coor[channelAxis + 1] = 0;
            coor[channelAxis + 2] = 0;

            // Top and bottom bands span full rows, so each is a single contiguous run.
            final int indexTop = tensor.idx(coor);
            Arrays.fill(tensor.d, indexTop, indexTop + borderY0 * width, value);
            coor[channelAxis + 1] = height - borderY1;
            final int indexBottom = tensor.idx(coor);
            Arrays.fill(tensor.d, indexBottom, indexBottom + borderY1 * width, value);

            // Interior rows: fill only the left and right margins.
            for (int y = borderY0; y < height - borderY1; ++y) {
                coor[channelAxis + 1] = y;
                final int left = tensor.idx(coor);
                final int right = left + width - borderX1;
                Arrays.fill(tensor.d, left, left + borderX0, value);
                Arrays.fill(tensor.d, right, right + borderX1, value);
            }
        }
    }

    /**
     * Prints one (row, column) plane of a 4-D tensor indexed as (batch, channel, row, column) to
     * standard output. Each value is printed in scientific notation.
     *
     * @param tensor tensor to print; axes 2 and 3 are taken as rows and columns
     * @param batch batch index
     * @param channel channel index
     */
    public static void printSpatial(Tensor_F32 tensor, int batch, int channel) {
        final int rows = tensor.length(2);
        final int cols = tensor.length(3);
        System.out.println(tensor.getClass().getSimpleName() + " batch " + batch + "  channel " + channel);
        System.out.println("     rows " + rows + " columns " + cols);
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < cols; ++col) {
                // Was "%10.3fe ": fixed-point output followed by a stray literal 'e'.
                System.out.printf("%10.3e ", tensor.get(batch, channel, row, col));
            }
            System.out.println();
        }
    }

    /**
     * Sets all {@code length()} elements of {@code tensor}, starting at {@code startIndex}, to
     * {@code value}.
     *
     * @param tensor tensor to fill
     * @param value fill value
     */
    public static void fill(Tensor_F32 tensor, float value) {
        Arrays.fill(tensor.d, tensor.startIndex, tensor.startIndex + tensor.length(), value);
    }
}

