/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Ctable;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap;

/**
 * Control-program instruction executing the {@code ctable} and {@code ctableexpand} opcodes.
 *
 * <p>The concrete ctable variant is selected from the data types of the three inputs (see
 * {@link Ctable#findCtableOperationByInputDataTypes}); the {@code ctableexpand} opcode always maps
 * to {@link Ctable.OperationTypes#CTABLE_EXPAND_SCALAR_WEIGHT}.
 */
public class CtableCPInstruction
extends ComputationCPInstruction {
    /** Scalar operand holding the requested number of output rows; a value of -1 means unknown. */
    private final CPOperand _outDim1;
    /** Scalar operand holding the requested number of output columns; a value of -1 means unknown. */
    private final CPOperand _outDim2;
    /** True if the opcode is {@code ctableexpand}. */
    private final boolean _isExpand;
    /** Forwarded to the scalar-weight ctable operation (CTABLE_TRANSFORM_SCALAR_WEIGHT). */
    private final boolean _ignoreZeros;
    /** Forwarded to {@link LibMatrixReorg#fusedSeqRexpand}; presumably the degree of parallelism — TODO confirm. */
    private final int _k;

    private CtableCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, String opcode, String istr, int k) {
        super(CPInstruction.CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
        this._outDim1 = new CPOperand(outputDim1, Types.ValueType.FP64, Types.DataType.SCALAR, dim1Literal);
        this._outDim2 = new CPOperand(outputDim2, Types.ValueType.FP64, Types.DataType.SCALAR, dim2Literal);
        this._isExpand = isExpand;
        this._ignoreZeros = ignoreZeros;
        this._k = k;
    }

    /**
     * Parses a serialized ctable instruction.
     *
     * <p>Expected fields after the opcode: {@code in1, in2, in3, outDim1, outDim2, out, ignoreZeros, k},
     * where each output-dimension field has the form {@code <name>·<isLiteral>}.
     *
     * @param inst the serialized instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ctable}/{@code ctableexpand}, the field
     *         count is wrong, or an output-dimension field is malformed
     */
    public static CtableCPInstruction parseInstruction(String inst) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
        InstructionUtils.checkNumFields(parts, 8);
        String opcode = parts[0];
        boolean isCtable = opcode.equalsIgnoreCase(Opcodes.CTABLE.toString());
        boolean isExpand = opcode.equalsIgnoreCase(Opcodes.CTABLEEXPAND.toString());
        if (!isCtable && !isExpand) {
            throw new DMLRuntimeException("Unexpected opcode in CtableCPInstruction: " + inst);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        String[] dim1Fields = splitDimField(parts[4], inst);
        String[] dim2Fields = splitDimField(parts[5], inst);
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        int k = Integer.parseInt(parts[8]);
        return new CtableCPInstruction(in1, in2, in3, out,
            dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
            dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]),
            isExpand, ignoreZeros, opcode, inst, k);
    }

    /** Splits an output-dimension field {@code <name>·<isLiteral>} into its two parts, or fails with a clear error. */
    private static String[] splitDimField(String field, String inst) {
        String[] tokens = field.split("\u00b7");
        if (tokens.length < 2) {
            throw new DMLRuntimeException("Invalid output dimension field '" + field + "' in instruction: " + inst);
        }
        return tokens;
    }

    /** Derives the ctable operation variant from the data types (matrix/scalar) of the three inputs. */
    private Ctable.OperationTypes findCtableOperation() {
        Types.DataType dt1 = this.input1.getDataType();
        Types.DataType dt2 = this.input2.getDataType();
        Types.DataType dt3 = this.input3.getDataType();
        return Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    }

    /**
     * Executes the ctable operation and binds the resulting matrix to the output variable.
     *
     * <p>If both output dimensions are known (not -1), a result block may be preallocated: always
     * (sparse) for expand, and only when the estimate is dense otherwise. When no block is
     * preallocated, results are collected in a {@link CTableMap} and converted afterwards.
     *
     * @param ec the execution context providing inputs and receiving the output
     * @throws DMLRuntimeException if the input data types do not map to a supported ctable operation
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        // For ctableexpand, input1 is never read as a matrix (see the EXPAND case below).
        MatrixBlock matBlock1 = !this._isExpand ? ec.getMatrixInput(this.input1) : null;
        MatrixBlock matBlock2 = null;
        MatrixBlock wtBlock = null;
        CTableMap resultMap = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
        MatrixBlock resultBlock = null;

        Ctable.OperationTypes ctableOp = this.findCtableOperation();
        if (this._isExpand) {
            ctableOp = Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT;
        }

        long outputDim1 = ec.getScalarInput(this._outDim1).getLongValue();
        long outputDim2 = ec.getScalarInput(this._outDim2).getLongValue();
        boolean outputDimsKnown = outputDim1 != -1L && outputDim2 != -1L;
        if (outputDimsKnown) {
            if (this._isExpand) {
                resultBlock = new MatrixBlock((int) outputDim1, (int) outputDim2, true);
            } else {
                // Widen before multiplying: rows*cols of a large input overflows int and would
                // otherwise corrupt the sparsity estimate.
                long inputCells = (long) matBlock1.getNumRows() * matBlock1.getNumColumns();
                boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputCells);
                // Only a dense result is preallocated; sparse results go through resultMap and are
                // converted after the operation.
                if (!sparse) {
                    resultBlock = new MatrixBlock((int) outputDim1, (int) outputDim2, false);
                }
            }
        }

        switch (ctableOp) {
            case CTABLE_TRANSFORM: {
                matBlock2 = ec.getMatrixInput(this.input2.getName());
                wtBlock = ec.getMatrixInput(this.input3.getName());
                matBlock1.ctableOperations(this._optr, matBlock2, (MatrixValue) wtBlock, resultMap, resultBlock);
                break;
            }
            case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
                matBlock2 = ec.getMatrixInput(this.input2.getName());
                double weight = ec.getScalarInput(this.input3).getDoubleValue();
                matBlock1.ctableOperations(this._optr, matBlock2, weight, this._ignoreZeros, resultMap, resultBlock);
                break;
            }
            case CTABLE_EXPAND_SCALAR_WEIGHT: {
                if (this.input1.getDataType() == Types.DataType.MATRIX) {
                    LOG.warn("rewrite for table expand not activated please fix");
                }
                matBlock2 = ec.getMatrixInput(this.input2.getName());
                double weight = ec.getScalarInput(this.input3).getDoubleValue();
                resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, weight, resultBlock, !outputDimsKnown, this._k);
                break;
            }
            case CTABLE_TRANSFORM_HISTOGRAM: {
                double cst1 = ec.getScalarInput(this.input2).getDoubleValue();
                double cst2 = ec.getScalarInput(this.input3).getDoubleValue();
                matBlock1.ctableOperations(this._optr, cst1, cst2, resultMap, resultBlock);
                break;
            }
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
                wtBlock = ec.getMatrixInput(this.input3.getName());
                double cst1 = ec.getScalarInput(this.input2).getDoubleValue();
                matBlock1.ctableOperations(this._optr, cst1, (MatrixValue) wtBlock, resultMap, resultBlock);
                break;
            }
            default: {
                throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
            }
        }

        // Release only the matrix inputs that were actually acquired above.
        if (this.input1.getDataType() == Types.DataType.MATRIX && ctableOp != Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT) {
            ec.releaseMatrixInput(this.input1.getName());
        }
        if (this.input2.getDataType() == Types.DataType.MATRIX) {
            ec.releaseMatrixInput(this.input2.getName());
        }
        if (this.input3.getDataType() == Types.DataType.MATRIX) {
            ec.releaseMatrixInput(this.input3.getName());
        }

        if (resultBlock == null) {
            // Respect explicitly requested output dimensions when materializing the hash-map result.
            resultBlock = outputDimsKnown
                ? DataConverter.convertToMatrixBlock(resultMap, (int) outputDim1, (int) outputDim2)
                : DataConverter.convertToMatrixBlock(resultMap);
        } else {
            resultBlock.examSparsity();
        }
        if (this.checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock)) {
            resultBlock.examSparsity();
        }
        ec.setMatrixOutput(this.output.getName(), resultBlock);
    }

    /**
     * Builds the lineage item for this instruction's output. The output-dimension operands are
     * included as lineage inputs unless both are the literal name {@code "-1"}.
     */
    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        boolean bothDimsUnspecified = this._outDim1.getName().equals("-1") && this._outDim2.getName().equals("-1");
        LineageItem[] linputs = bothDimsUnspecified
            ? LineageItemUtils.getLineage(ec, this.input1, this.input2, this.input3)
            : LineageItemUtils.getLineage(ec, this.input1, this.input2, this.input3, this._outDim1, this._outDim2);
        return Pair.of(this.output.getName(), new LineageItem(this.getOpcode(), linputs));
    }

    public CPOperand getOutDim1() {
        return this._outDim1;
    }

    public CPOperand getOutDim2() {
        return this._outDim2;
    }

    public boolean getIsExpand() {
        return this._isExpand;
    }

    public boolean getIgnoreZeros() {
        return this._ignoreZeros;
    }
}

