/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public final class CLALibScalar {
    /** Logger shared by all static helpers in this class. */
    private static final Log LOG = LogFactory.getLog(CLALibScalar.class.getName());

    /**
     * Size threshold (8096). NOTE(review): not referenced by name in the visible code;
     * presumably the bound below which column groups are batched together for
     * parallel execution - confirm against the partitioning logic.
     */
    private static final int MINIMUM_PARALLEL_SIZE = 8096;

    /** Static utility holder; never instantiated. */
    private CLALibScalar() {
        // intentionally empty
    }

    /**
     * Applies a scalar operator to a compressed matrix block.
     *
     * <p>If the input is overlapping and the operator cannot produce a compressed output
     * (see {@code isInvalidForCompressedOutput}), a warning is logged and the operation is fused
     * with decompression, returning an uncompressed block. Otherwise the operator is applied to
     * the column groups and a {@link CompressedMatrixBlock} is returned.
     *
     * @param sop    scalar operator to apply
     * @param m1     compressed input block
     * @param result optional output block; reused only when it is a {@link CompressedMatrixBlock}
     * @return the result of the scalar operation
     */
    public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) {
        if (CLALibScalar.isInvalidForCompressedOutput(m1, sop)) {
            LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName());
            return CLALibScalar.fusedScalarAndDecompress(m1, sop);
        }
        CompressedMatrixBlock ret = CLALibScalar.setupRet(m1, result);
        List<AColGroup> colGroups = m1.getColGroups();
        if (m1.isOverlapping() && !(sop.fn instanceof Multiply) && !(sop.fn instanceof Divide)) {
            // Overlapping additive case: the scalar offset is carried by one extra constant
            // group instead of being applied to every overlapping group.
            double v0 = sop.executeScalar(0.0);
            ColGroupConst c = v0 != 0.0 ? CLALibScalar.constOverlap(m1, v0) : null;
            boolean isMinus = sop instanceof LeftScalarOperator && sop.fn instanceof Minus;
            List<AColGroup> newColGroups = isMinus
                ? CLALibScalar.copyGroupsAndMultMinus(m1, sop, c, ret)
                : CLALibScalar.copyGroups(m1, sop, c, ret);
            ret.allocateColGroupList(newColGroups);
            ret.setOverlapping(true);
        } else {
            int threadsAvailable = sop.getNumThreads() > 1
                ? sop.getNumThreads()
                : OptimizerUtils.getConstrainedNumThreads(-1);
            if (threadsAvailable > 1) {
                CLALibScalar.parallelScalarOperations(sop, colGroups, ret, threadsAvailable);
            } else {
                List<AColGroup> newColGroups = new ArrayList<>();
                for (AColGroup grp : colGroups) {
                    newColGroups.add(grp.scalarOperation(sop));
                }
                ret.allocateColGroupList(newColGroups);
            }
            ret.setOverlapping(m1.isOverlapping());
        }
        if (CLALibScalar.keepsNonZeroCount(sop)) {
            ret.setNonZeros(m1.getNonZeros());
        } else {
            ret.recomputeNonZeros();
        }
        return ret;
    }

    /**
     * Decides whether the input's non-zero count can be copied to the output without recounting.
     *
     * <p>Only a division by a right-hand scalar that is finite and non-zero qualifies. A left-hand
     * divide ({@code s / x}) turns zero cells into Inf/NaN, and dividing by 0, NaN or infinity
     * also changes which cells are zero, so those cases must recompute.
     *
     * @param sop scalar operator being applied
     * @return {@code true} when copying the input's non-zero count is safe
     */
    private static boolean keepsNonZeroCount(ScalarOperator sop) {
        if (!(sop.fn instanceof Divide) || sop instanceof LeftScalarOperator) {
            return false;
        }
        // Probe with 1.0: for a right-hand divisor s this presumably evaluates 1 / s, which is
        // finite and non-zero only when s itself is a usable (finite, non-zero) divisor.
        final double probe = sop.executeScalar(1.0);
        return probe != 0.0 && !Double.isNaN(probe) && !Double.isInfinite(probe);
    }

    /**
     * Decompresses the input while applying the scalar operator, choosing the parallel
     * implementation only when the operator requests more than one thread.
     *
     * @param in  compressed input block
     * @param sop scalar operator to apply
     * @return the uncompressed result
     */
    private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) {
        final boolean multiThreaded = sop.getNumThreads() > 1;
        return multiThreaded
            ? CLALibScalar.parallelFusedScalarAndDecompress(in, sop)
            : CLALibScalar.singleThreadFusedScalarAndDecompress(in, sop);
    }

    /**
     * Single-threaded fused decompression and scalar operation over all rows.
     *
     * @param in  compressed input block
     * @param sop scalar operator to apply
     * @return a newly allocated block holding the result, with its non-zero count set and
     *         sparsity examined
     */
    private static MatrixBlock singleThreadFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) {
        final int rows = in.getNumRows();
        final int cols = in.getNumColumns();

        final MatrixBlock dense = new MatrixBlock(rows, cols, false);
        dense.allocateDenseBlock();
        final DenseBlock target = dense.getDenseBlock();

        // Process the full row range [0, rows) in one call.
        final long nonZeros = CLALibScalar.fusedDecompressAndScalar(in.getColGroups(), cols, 0, rows, target, sop);

        dense.setNonZeros(nonZeros);
        dense.examSparsity(true);
        return dense;
    }

    /**
     * Multi-threaded fused decompression and scalar operation.
     *
     * <p>Rows are split into contiguous ranges of at least 256 rows, one task per range; each
     * task writes its own rows of the shared dense block and returns the non-zeros it produced.
     *
     * @param in  compressed input block
     * @param sop scalar operator to apply; its thread count sizes the pool
     * @return a newly allocated block holding the result
     * @throws DMLCompressionException if any task fails or the calling thread is interrupted
     */
    private static MatrixBlock parallelFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) {
        int k = sop.getNumThreads();
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            int nRow = in.getNumRows();
            int nCol = in.getNumColumns();
            MatrixBlock out = new MatrixBlock(nRow, nCol, false);
            List<AColGroup> groups = in.getColGroups();
            out.allocateDenseBlock();
            DenseBlock db = out.getDenseBlock();
            int blkz = Math.max((int)Math.ceil((double)nRow / (double)k), 256);
            List<Future<Long>> tasks = new ArrayList<>();
            for (int i = 0; i < nRow; i += blkz) {
                int start = i;
                int end = Math.min(i + blkz, nRow);
                tasks.add(pool.submit(() -> CLALibScalar.fusedDecompressAndScalar(groups, nCol, start, end, db, sop)));
            }
            long nnz = 0L;
            for (Future<Long> task : tasks) {
                nnz += task.get();
            }
            out.setNonZeros(nnz);
            out.examSparsity(true, k);
            return out;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("failed fused scalar operation", e);
        }
        catch (Exception e) {
            throw new DMLCompressionException("failed fused scalar operation", e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Runs the fused decompress-and-scalar step over rows {@code [start, end)} in chunks of
     * 32 rows.
     *
     * @param groups column groups to decompress
     * @param nCol   number of columns per row
     * @param start  first row (inclusive)
     * @param end    last row (exclusive)
     * @param db     dense block receiving the output
     * @param sop    scalar operator to apply
     * @return number of non-zero cells produced in the row range
     */
    private static long fusedDecompressAndScalar(List<AColGroup> groups, int nCol, int start, int end, DenseBlock db, ScalarOperator sop) {
        long nonZeros = 0L;
        int chunkStart = start;
        while (chunkStart < end) {
            final int chunkEnd = Math.min(chunkStart + 32, end);
            nonZeros += CLALibScalar.fusedDecompressAndScalarBlock(groups, nCol, chunkStart, chunkEnd, db, sop);
            chunkStart += 32;
        }
        return nonZeros;
    }

    /**
     * Decompresses rows {@code [bs, be)} of every column group into the dense block and then
     * applies the scalar operator in place to each cell of those rows.
     *
     * @param groups column groups to decompress
     * @param nCol   number of columns per row
     * @param bs     first row (inclusive)
     * @param be     last row (exclusive)
     * @param db     dense block receiving the output
     * @param sop    scalar operator to apply
     * @return number of resulting cells that are not equal to zero
     */
    private static long fusedDecompressAndScalarBlock(List<AColGroup> groups, int nCol, int bs, int be, DenseBlock db, ScalarOperator sop) {
        // Step 1: materialize the raw values of all groups for this row range.
        for (AColGroup group : groups) {
            group.decompressToDenseBlock(db, bs, be);
        }

        // Step 2: apply the operator cell by cell and count the non-zero results.
        long nonZeros = 0L;
        for (int row = bs; row < be; row++) {
            final double[] rowValues = db.values(row);
            final int rowBegin = db.pos(row);
            final int rowEnd = nCol + rowBegin;
            for (int idx = rowBegin; idx < rowEnd; idx++) {
                rowValues[idx] = sop.executeScalar(rowValues[idx]);
                if (!(rowValues[idx] == 0.0)) {
                    nonZeros++;
                }
            }
        }
        return nonZeros;
    }

    /**
     * Provides the compressed output block for an operation on {@code m1}.
     *
     * @param m1     input block whose dimensions the output must match
     * @param result candidate output; reused when it is a {@link CompressedMatrixBlock}
     * @return {@code result} with its dimensions updated, or a new block when {@code result}
     *         is {@code null} or of another type
     */
    private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
        // instanceof is false for null, so this also covers a missing result.
        if (!(result instanceof CompressedMatrixBlock)) {
            return new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
        }
        final CompressedMatrixBlock reused = (CompressedMatrixBlock)result;
        reused.setNumColumns(m1.getNumColumns());
        reused.setNumRows(m1.getNumRows());
        return reused;
    }

    /**
     * Creates a constant column group with value {@code v}, sized by the column count of
     * {@code m1}.
     *
     * @param m1 block supplying the number of columns
     * @param v  constant value
     * @return the created group, cast to {@link ColGroupConst}
     */
    private static ColGroupConst constOverlap(CompressedMatrixBlock m1, double v) {
        final int numColumns = m1.getNumColumns();
        final AColGroup created = ColGroupConst.create(numColumns, v);
        return (ColGroupConst)created;
    }

    /**
     * Builds the column group list for an overlapping result where the scalar offset is
     * carried by the constant group {@code c}.
     *
     * <p>Empty groups are dropped. When {@code c} is present, the values of every existing
     * constant group are added into {@code c}'s value array (mutating it) and that group is
     * dropped; all other groups are passed through unchanged. {@code c}, if present, is
     * appended last.
     *
     * @param m1  input block supplying the column groups
     * @param sop not used by this method
     * @param c   constant group accumulating offsets, or {@code null} when there is no offset
     * @param ret not used by this method
     * @return the new list of column groups
     */
    private static List<AColGroup> copyGroups(CompressedMatrixBlock m1, ScalarOperator sop, ColGroupConst c, CompressedMatrixBlock ret) {
        final double[] accumulated = (c == null) ? null : c.getValues();
        final List<AColGroup> source = m1.getColGroups();
        final List<AColGroup> copied = new ArrayList<>(source.size() + 1);

        for (AColGroup group : source) {
            if (group instanceof ColGroupEmpty) {
                continue;
            }
            final boolean foldIntoConst = accumulated != null && group instanceof ColGroupConst;
            if (!foldIntoConst) {
                copied.add(group);
                continue;
            }
            // Merge this constant group into the accumulator, column by column.
            final double[] groupValues = ((ColGroupConst)group).getValues();
            final IColIndex columns = group.getColIndices();
            for (int j = 0; j < columns.size(); j++) {
                final int col = columns.get(j);
                accumulated[col] = accumulated[col] + groupValues[j];
            }
        }

        if (c != null) {
            copied.add(c);
        }
        return copied;
    }

    /**
     * Builds the column group list for an overlapping result of a left-hand minus
     * ({@code s - X}): every group is negated and the scalar offset is carried by the
     * constant group {@code c}.
     *
     * <p>Empty groups are dropped. When {@code c} is present, the values of every existing
     * constant group are subtracted from {@code c}'s value array (mutating it) and that group
     * is dropped. All remaining groups are multiplied by {@code -1}. {@code c}, if present,
     * is appended exactly once at the end.
     *
     * @param m1  input block supplying the column groups
     * @param sop not used by this method
     * @param c   constant group accumulating offsets, or {@code null} when there is no offset
     * @param ret not used by this method
     * @return the new list of column groups; never contains {@code null}
     */
    private static List<AColGroup> copyGroupsAndMultMinus(CompressedMatrixBlock m1, ScalarOperator sop, ColGroupConst c, CompressedMatrixBlock ret) {
        final double[] constV = c != null ? c.getValues() : null;
        final RightScalarOperator negate = new RightScalarOperator(Multiply.getMultiplyFnObject(), -1.0);
        final List<AColGroup> newColGroups = new ArrayList<>();
        for (AColGroup grp : m1.getColGroups()) {
            if (grp instanceof ColGroupEmpty) {
                continue;
            }
            if (constV != null && grp instanceof ColGroupConst) {
                // Fold the negated constant into the accumulator instead of keeping the group.
                final double[] gv = ((ColGroupConst)grp).getValues();
                final IColIndex colIdx = grp.getColIndices();
                for (int i = 0; i < colIdx.size(); ++i) {
                    final int col = colIdx.get(i);
                    constV[col] = constV[col] - gv[i];
                }
                continue;
            }
            // Without an accumulator (the 0 - X case) constant groups must be negated like
            // every other group; passing them through unchanged would flip their sign.
            newColGroups.add(grp.scalarOperation(negate));
        }
        // Append the offset group once, and only when it exists: adding it twice would apply
        // the offset twice, and adding null would break later consumers of the list.
        if (c != null) {
            newColGroups.add(c);
        }
        return newColGroups;
    }

    /**
     * Tells whether the scalar operation cannot keep the output compressed.
     *
     * <p>Only overlapping inputs can be invalid. For those, multiply, plus, minus and a
     * right-hand divide are accepted; every other operator is rejected.
     *
     * @param m1  compressed input block
     * @param sop scalar operator to apply
     * @return {@code true} when the caller must fall back to a decompressed result
     */
    private static boolean isInvalidForCompressedOutput(CompressedMatrixBlock m1, ScalarOperator sop) {
        if (!m1.isOverlapping()) {
            return false;
        }
        if (sop.fn instanceof Multiply) {
            return false;
        }
        final boolean rightDivide = sop.fn instanceof Divide && sop instanceof RightScalarOperator;
        if (rightDivide) {
            return false;
        }
        return !(sop.fn instanceof Plus || sop.fn instanceof Minus);
    }

    /**
     * Applies the scalar operator to the column groups using a thread pool and stores the
     * resulting groups in {@code ret}.
     *
     * <p>Does nothing when {@code colGroups} is {@code null}.
     *
     * @param sop       scalar operator to apply
     * @param colGroups column groups to process
     * @param ret       output block receiving the new column group list
     * @param k         number of threads requested from the pool
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static void parallelScalarOperations(ScalarOperator sop, List<AColGroup> colGroups, CompressedMatrixBlock ret, int k) {
        if (colGroups == null) {
            return;
        }
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            List<ScalarTask> tasks = CLALibScalar.partition(sop, colGroups);
            List<Future<List<AColGroup>>> rtasks = pool.invokeAll(tasks);
            List<AColGroup> newColGroups = new ArrayList<>();
            for (Future<List<AColGroup>> f : rtasks) {
                newColGroups.addAll(f.get());
            }
            ret.allocateColGroupList(newColGroups);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Splits the column groups into tasks for parallel execution.
     *
     * <p>Uncompressed groups, OLE groups, and groups whose {@code numValues * numCols} is at
     * least 8096 each get a task of their own. The remaining groups are collected into batches;
     * a batch is emitted as one task as soon as it holds more than 10 groups, and any leftover
     * batch is emitted at the end.
     *
     * @param sop       scalar operator handed to every task
     * @param colGroups column groups to distribute
     * @return the list of tasks
     */
    private static List<ScalarTask> partition(ScalarOperator sop, List<AColGroup> colGroups) {
        final List<ScalarTask> tasks = new ArrayList<>();
        List<AColGroup> batch = new ArrayList<>();

        for (AColGroup group : colGroups) {
            final boolean ownTask;
            if (group instanceof ColGroupUncompressed) {
                ownTask = true;
            } else {
                final int cells = group.getNumValues() * group.getNumCols();
                ownTask = cells >= 8096 || group instanceof ColGroupOLE;
            }

            if (ownTask) {
                final List<AColGroup> single = new ArrayList<>();
                single.add(group);
                tasks.add(new ScalarTask(single, sop));
            } else {
                batch.add(group);
            }

            // Flush the batch once it has grown past 10 groups.
            if (batch.size() > 10) {
                tasks.add(new ScalarTask(batch, sop));
                batch = new ArrayList<>();
            }
        }

        if (!batch.isEmpty()) {
            tasks.add(new ScalarTask(batch, sop));
        }
        return tasks;
    }

    /**
     * Task that applies one scalar operator to a list of column groups and returns the
     * resulting groups in the same order.
     */
    private static class ScalarTask implements Callable<List<AColGroup>> {
        /** Column groups this task processes. */
        private final List<AColGroup> _colGroups;
        /** Operator applied to each group. */
        private final ScalarOperator _sop;

        /**
         * @param colGroups column groups to process
         * @param sop       scalar operator to apply
         */
        protected ScalarTask(List<AColGroup> colGroups, ScalarOperator sop) {
            _colGroups = colGroups;
            _sop = sop;
        }

        /**
         * Applies the operator to every group.
         *
         * @return a new list holding one result group per input group
         */
        @Override
        public List<AColGroup> call() {
            final List<AColGroup> transformed = new ArrayList<>();
            for (final AColGroup group : _colGroups) {
                final AColGroup updated = group.scalarOperation(_sop);
                transformed.add(updated);
            }
            return transformed;
        }
    }
}

