/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.codegen;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Base class for operators that compute several full aggregates over the
 * cells of a main input matrix in a single pass.
 *
 * <p>Subclasses supply the per-cell logic through
 * {@link #genexec(double, SpoofOperator.SideInput[], double[], double[], int, int, long, int, int)};
 * this class drives the iteration over dense or sparse inputs, optionally
 * splits the rows across multiple tasks, and merges the per-task partial
 * results. The result is always a {@code 1 x getAggOps().length} block.
 */
public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable {
    private static final long serialVersionUID = -6164871955591089349L;

    /**
     * Inputs with fewer cells (or fewer non-zeros, for sparse-safe operators)
     * than this are always processed serially, regardless of the requested
     * degree of parallelism.
     */
    private static final long MIN_CELLS_FOR_PARALLEL = 1L << 20;

    private final SpoofCellwise.AggOp[] _aggOps;
    private final boolean _sparseSafe;

    /**
     * @param sparseSafe if {@code true}, cells that are not explicitly stored
     *                   in the main input are skipped instead of being passed
     *                   to {@code genexec} as {@code 0.0}
     * @param aggOps     one aggregation type per output column; the array is
     *                   stored as given (not copied)
     */
    public SpoofMultiAggregate(boolean sparseSafe, SpoofCellwise.AggOp... aggOps) {
        this._sparseSafe = sparseSafe;
        this._aggOps = aggOps;
    }

    /** @return the aggregation types, one per output column (internal array, not a copy) */
    public SpoofCellwise.AggOp[] getAggOps() {
        return this._aggOps;
    }

    /** @return whether zero cells of the main input are skipped */
    public boolean isSparseSafe() {
        return this._sparseSafe;
    }

    /**
     * Returns {@code "MA"} followed by the second dot-separated component of
     * the runtime class name.
     *
     * <p>NOTE(review): this indexes element {@code [1]} unconditionally, so it
     * assumes the class name contains at least one dot.
     */
    @Override
    public String getSpoofType() {
        return "MA" + this.getClass().getName().split("\\.")[1];
    }

    /** Always returns {@code null}; this class creates no CUDA operator. */
    @Override
    public SpoofCUDAOperator createCUDAInstrcution(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
        return null;
    }

    /** Serial execution with a row offset of zero. */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
        return this.execute(inputs, scalarObjects, out, 1, 0L);
    }

    /** Execution with up to {@code k} threads and a row offset of zero. */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) {
        return this.execute(inputs, scalarObjects, out, k, 0L);
    }

    /**
     * Computes all aggregates over the first input matrix and writes them
     * into {@code out}.
     *
     * @param inputs        input matrices; element 0 is the main input that is
     *                      iterated, the list is also handed to
     *                      {@code prepInputMatrices} to build the side inputs
     * @param scalarObjects scalar inputs, converted via {@code prepInputScalars}
     * @param out           output block; reset to a dense
     *                      {@code 1 x getAggOps().length} block
     * @param k             requested degree of parallelism; forced to 1 when
     *                      the input is smaller than
     *                      {@link #MIN_CELLS_FOR_PARALLEL}
     * @param rix           offset added to every local row index to form the
     *                      {@code long} row index passed to {@code genexec};
     *                      applied identically in the serial and parallel path
     * @return {@code out}
     * @throws RuntimeException    if {@code inputs} is {@code null} or empty
     * @throws DMLRuntimeException if parallel execution fails or is interrupted
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k, long rix) {
        if (inputs == null || inputs.isEmpty()) {
            throw new RuntimeException("Invalid input arguments.");
        }
        boolean sparseSafe = this.isSparseSafe();
        long inputSize = sparseSafe
            ? SpoofMultiAggregate.getTotalInputNnz(inputs)
            : SpoofMultiAggregate.getTotalInputSize(inputs);
        int numThreads = (inputSize < MIN_CELLS_FOR_PARALLEL) ? 1 : k;

        // Allocate the output and seed every column with its neutral element.
        out.reset(1, this._aggOps.length, false);
        out.allocateDenseBlock();
        double[] c = out.getDenseBlockValues();
        this.setInitialOutputValues(c);

        SpoofOperator.SideInput[] b = this.prepInputMatrices(inputs);
        double[] scalars = SpoofMultiAggregate.prepInputScalars(scalarObjects);
        MatrixBlock main = inputs.get(0);
        int m = main.getNumRows();
        int n = main.getNumColumns();

        if (numThreads <= 1) {
            this.executeRange(main, b, scalars, c, m, n, sparseSafe, 0, m, rix);
        } else {
            this.executeParallel(main, b, scalars, c, m, n, sparseSafe, numThreads, rix);
        }

        out.recomputeNonZeros();
        out.examSparsity();
        return out;
    }

    /**
     * Splits the rows of {@code a} into contiguous ranges, aggregates each
     * range in its own task, and merges the partial results into {@code c}.
     * The pool is shut down even if task submission or retrieval fails.
     */
    private void executeParallel(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int k, long rix) {
        try {
            ExecutorService pool = CommonThreadPool.get(k);
            try {
                int nk = UtilFunctions.roundToNext(Math.min(8 * k, m / 32), k);
                int blklen = (int) Math.ceil((double) m / (double) nk);
                List<ParAggTask> tasks = new ArrayList<>();
                for (int i = 0; i < nk; i++) {
                    // long arithmetic: i * blklen may exceed the int range for very tall inputs
                    long rl = (long) i * blklen;
                    if (rl >= m) {
                        break;
                    }
                    int ru = (int) Math.min(rl + blklen, (long) m);
                    tasks.add(new ParAggTask(a, b, scalars, m, n, sparseSafe, (int) rl, ru, rix));
                }
                List<Future<double[]>> taskret = pool.invokeAll(tasks);
                ArrayList<double[]> pret = new ArrayList<>(taskret.size());
                for (Future<double[]> task : taskret) {
                    pret.add(task.get());
                }
                this.aggregatePartialResults(c, pret);
            } finally {
                pool.shutdown();
            }
        } catch (InterruptedException ex) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ex);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Aggregates rows {@code [rl, ru)} of {@code a} into {@code c}, choosing
     * the dense or sparse kernel according to the block's current format.
     */
    private void executeRange(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru, long rix) {
        if (a.isInSparseFormat()) {
            this.executeSparse(a.getSparseBlock(), b, scalars, c, m, n, sparseSafe, rl, ru, rix);
        } else {
            this.executeDense(a.getDenseBlock(), b, scalars, c, m, n, sparseSafe, rl, ru, rix);
        }
    }

    /**
     * Dense kernel: invokes {@code genexec} for every cell in rows
     * {@code [rl, ru)}. A {@code null} block is treated as all zeros; in that
     * case nothing is done for sparse-safe operators.
     */
    private void executeDense(DenseBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru, long rix) {
        SpoofOperator.SideInput[] lb = SpoofMultiAggregate.createSparseSideInputs(b);
        if (a == null) {
            if (sparseSafe) {
                return;
            }
            for (int i = rl; i < ru; ++i) {
                for (int j = 0; j < n; ++j) {
                    this.genexec(0.0, lb, scalars, c, m, n, rix + (long) i, i, j);
                }
            }
            return;
        }
        for (int i = rl; i < ru; ++i) {
            double[] avals = a.values(i);
            int aix = a.pos(i);
            for (int j = 0; j < n; ++j) {
                this.genexec(avals[aix + j], lb, scalars, c, m, n, rix + (long) i, i, j);
            }
        }
    }

    /**
     * Sparse kernel: invokes {@code genexec} for every stored cell in rows
     * {@code [rl, ru)}. For operators that are not sparse-safe, the gaps
     * between stored cells (and empty rows) are filled with calls for
     * {@code 0.0} so that every column index is visited in ascending order.
     */
    private void executeSparse(SparseBlock sblock, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru, long rix) {
        if (sblock == null && sparseSafe) {
            return;
        }
        SpoofOperator.SideInput[] lb = SpoofMultiAggregate.createSparseSideInputs(b);
        for (int i = rl; i < ru; ++i) {
            int lastj = -1; // last column index already handled in this row
            if (sblock != null && !sblock.isEmpty(i)) {
                int apos = sblock.pos(i);
                int alen = sblock.size(i);
                int[] aix = sblock.indexes(i);
                double[] avals = sblock.values(i);
                for (int k = apos; k < apos + alen; ++k) {
                    if (!sparseSafe) {
                        // zeros between the previous stored cell and this one
                        for (int j = lastj + 1; j < aix[k]; ++j) {
                            this.genexec(0.0, lb, scalars, c, m, n, rix + (long) i, i, j);
                        }
                    }
                    lastj = aix[k];
                    this.genexec(avals[k], lb, scalars, c, m, n, rix + (long) i, i, lastj);
                }
            }
            if (sparseSafe) {
                continue;
            }
            // trailing zeros after the last stored cell (or the entire row if empty)
            for (int j = lastj + 1; j < n; ++j) {
                this.genexec(0.0, lb, scalars, c, m, n, rix + (long) i, i, j);
            }
        }
    }

    /**
     * Convenience overload that uses {@code rix} both as the {@code long} row
     * index and as the local row index.
     */
    protected final void genexec(double a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, int rix, int cix) {
        this.genexec(a, b, scalars, c, m, n, rix, rix, cix);
    }

    /**
     * Per-cell hook implemented by subclasses.
     *
     * @param a       value of the main input at ({@code rix}, {@code cix})
     * @param b       side inputs
     * @param scalars scalar inputs
     * @param c       aggregate buffer with one entry per aggregation type
     * @param m       number of rows of the main input
     * @param n       number of columns of the main input
     * @param grix    local row index plus the row offset passed to {@code execute}
     * @param rix     local row index within the main input
     * @param cix     column index
     */
    protected abstract void genexec(double a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int m, int n, long grix, int rix, int cix);

    /** Fills {@code c} with the initial value of each configured aggregation type. */
    private void setInitialOutputValues(double[] c) {
        for (int k = 0; k < this._aggOps.length; ++k) {
            c[k] = SpoofMultiAggregate.getInitialValue(this._aggOps[k]);
        }
    }

    /**
     * @param aggop aggregation type
     * @return {@code 0.0} for SUM and SUM_SQ, positive infinity for MIN,
     *         negative infinity for MAX, and {@code 0.0} for any other type
     */
    public static double getInitialValue(SpoofCellwise.AggOp aggop) {
        switch (aggop) {
            case SUM:
            case SUM_SQ:
                return 0.0;
            case MIN:
                return Double.POSITIVE_INFINITY;
            case MAX:
                return Double.NEGATIVE_INFINITY;
            default:
                return 0.0;
        }
    }

    /**
     * Merges per-task partial results into {@code c}.
     *
     * <p>Columns whose aggregation function is a {@link KahanFunction} are
     * summed with {@link KahanPlus} starting from zero (the previous content
     * of {@code c[k]} is overwritten). All other columns fold each partial
     * value into the current {@code c[k]} with the column's function.
     */
    private void aggregatePartialResults(double[] c, ArrayList<double[]> pret) {
        ValueFunction[] vfun = SpoofMultiAggregate.getAggFunctions(this._aggOps);
        for (int k = 0; k < this._aggOps.length; ++k) {
            if (vfun[k] instanceof KahanFunction) {
                KahanObject kbuff = new KahanObject(0.0, 0.0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                for (double[] partial : pret) {
                    kplus.execute2(kbuff, partial[k]);
                }
                c[k] = kbuff._sum;
            } else {
                for (double[] partial : pret) {
                    c[k] = vfun[k].execute(c[k], partial[k]);
                }
            }
        }
    }

    /**
     * Merges the partial result {@code b} into {@code c}, column by column;
     * both are read and written in row 0 only.
     *
     * @param aggOps aggregation type per column
     * @param c      accumulated result, updated in place
     * @param b      partial result to fold into {@code c}
     */
    public static void aggregatePartialResults(SpoofCellwise.AggOp[] aggOps, MatrixBlock c, MatrixBlock b) {
        ValueFunction[] vfun = SpoofMultiAggregate.getAggFunctions(aggOps);
        for (int k = 0; k < aggOps.length; ++k) {
            if (vfun[k] instanceof KahanFunction) {
                KahanObject kbuff = new KahanObject(c.quickGetValue(0, k), 0.0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                kplus.execute2(kbuff, b.quickGetValue(0, k));
                c.quickSetValue(0, k, kbuff._sum);
            } else {
                double cval = c.quickGetValue(0, k);
                double bval = b.quickGetValue(0, k);
                c.quickSetValue(0, k, vfun[k].execute(cval, bval));
            }
        }
    }

    /**
     * Maps each aggregation type to the function object used to combine values.
     *
     * @param aggOps aggregation types
     * @return one function per entry of {@code aggOps}, in the same order
     * @throws RuntimeException for an aggregation type other than SUM, SUM_SQ, MIN, MAX
     */
    public static ValueFunction[] getAggFunctions(SpoofCellwise.AggOp[] aggOps) {
        ValueFunction[] fun = new ValueFunction[aggOps.length];
        for (int i = 0; i < aggOps.length; ++i) {
            switch (aggOps[i]) {
                case SUM:
                    fun[i] = KahanPlus.getKahanPlusFnObject();
                    break;
                case SUM_SQ:
                    fun[i] = KahanPlusSq.getKahanPlusSqFnObject();
                    break;
                case MIN:
                    fun[i] = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN);
                    break;
                case MAX:
                    fun[i] = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX);
                    break;
                default:
                    throw new RuntimeException("Unsupported aggregation type: " + aggOps[i].name());
            }
        }
        return fun;
    }

    /**
     * Task that aggregates one contiguous row range of the main input into a
     * freshly allocated, task-private result vector.
     */
    private class ParAggTask implements Callable<double[]> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final double[] _scalars;
        private final int _rlen;
        private final int _clen;
        private final boolean _safe;
        private final int _rl;
        private final int _ru;
        private final long _rix;

        /**
         * @param a       main input
         * @param b       side inputs
         * @param scalars scalar inputs
         * @param rlen    number of rows of the main input
         * @param clen    number of columns of the main input
         * @param safe    whether zero cells are skipped
         * @param rl      first row of the range (inclusive)
         * @param ru      end of the range (exclusive)
         * @param rix     row offset forwarded to the kernels, so that the
         *                parallel path sees the same row indices as the
         *                serial one
         */
        protected ParAggTask(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, int rlen, int clen, boolean safe, int rl, int ru, long rix) {
            this._a = a;
            this._b = b;
            this._scalars = scalars;
            this._rlen = rlen;
            this._clen = clen;
            this._safe = safe;
            this._rl = rl;
            this._ru = ru;
            this._rix = rix;
        }

        /** @return the partial aggregates for rows {@code [_rl, _ru)} */
        @Override
        public double[] call() {
            double[] c = new double[SpoofMultiAggregate.this._aggOps.length];
            SpoofMultiAggregate.this.setInitialOutputValues(c);
            SpoofMultiAggregate.this.executeRange(this._a, this._b, this._scalars, c, this._rlen, this._clen, this._safe, this._rl, this._ru, this._rix);
            return c;
        }
    }
}

