/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite.plan;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.commons.lang3.tuple.Pair;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.ImmutablePPLAggregateConvertRule;
import shaded.com.google.common.collect.ImmutableList;

@Value.Enclosing
public class PPLAggregateConvertRule
extends RelRule<Config> {
    protected PPLAggregateConvertRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        if (call.rels.length != 2) {
            throw new AssertionError((Object)String.format("The length of rels should be %s but got %s", this.operands.size(), call.rels.length));
        }
        LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
        LogicalProject project = (LogicalProject)call.rel(1);
        this.apply(call, aggregate, project);
    }

    public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {
        RelBuilder relBuilder = call.builder();
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        relBuilder.push(project.getInput());
        List<AggregateCall> aggCalls = aggregate.getAggCallList();
        ArrayList<RexNode> newChildProjects = new ArrayList<RexNode>(project.getProjects());
        List<Integer> convertedAggCallArgs = aggCalls.stream().filter(aggCall -> this.isConvertableAggCall((AggregateCall)aggCall, project)).map(aggCall -> {
            RexInputRef rexRef = PPLAggregateConvertRule.getFieldAndLiteral(project.getProjects().get(aggCall.getArgList().getFirst())).getLeft();
            int ref = newChildProjects.indexOf(rexRef);
            if (ref == -1) {
                ref = newChildProjects.size();
                newChildProjects.add(rexRef);
            }
            return ref;
        }).toList();
        relBuilder.project(newChildProjects);
        RelNode newInput = relBuilder.peek();
        int convertedAggCallCnt = 0;
        int groupSetOffset = aggregate.getGroupSet().cardinality();
        ArrayList<AggregateCall> distinctAggregateCalls = new ArrayList<AggregateCall>();
        PairList<OperatorConstructor, String> newExprOnAggCall = PairList.of();
        for (int i = 0; i < aggregate.getAggCallList().size(); ++i) {
            AggregateCall aggCall2 = aggregate.getAggCallList().get(i);
            if (this.isConvertableAggCall(aggCall2, project)) {
                Function<RelNode, Function> literalConverterProvider;
                int argRef = convertedAggCallArgs.get(convertedAggCallCnt++);
                AggregateCall sumCall = AggregateCall.create(aggCall2.getParserPosition(), aggCall2.getAggregation(), aggCall2.isDistinct(), aggCall2.isApproximate(), aggCall2.ignoreNulls(), aggCall2.rexList, ImmutableList.of(Integer.valueOf(argRef)), aggCall2.filterArg, aggCall2.distinctKeys, aggCall2.collation, aggregate.getGroupCount(), newInput, null, aggCall2.getName() + "_SUM");
                int sumCallRef = this.putToDistinctAggregateCalls(distinctAggregateCalls, sumCall);
                RexCall rexCall = (RexCall)project.getProjects().get(aggCall2.getArgList().getFirst());
                if (rexCall.getOperator().kind == SqlKind.PLUS || rexCall.getOperator().kind == SqlKind.MINUS) {
                    AggregateCall countCall = AggregateCall.create(aggCall2.getParserPosition(), SqlStdOperatorTable.COUNT, aggCall2.isDistinct(), aggCall2.isApproximate(), aggCall2.ignoreNulls(), aggCall2.rexList, ImmutableList.of(Integer.valueOf(argRef)), aggCall2.filterArg, aggCall2.distinctKeys, aggCall2.collation, aggregate.getGroupCount(), newInput, null, aggCall2.getName() + "_COUNT");
                    int countCallRef = this.putToDistinctAggregateCalls(distinctAggregateCalls, countCall);
                    literalConverterProvider = input -> literal -> rexBuilder.makeCall(aggCall2.getType(), (SqlOperator)SqlStdOperatorTable.MULTIPLY, List.of(rexBuilder.makeInputRef((RelNode)input, groupSetOffset + countCallRef), literal));
                } else {
                    literalConverterProvider = input -> literal -> literal;
                }
                newExprOnAggCall.add(input -> {
                    Function<RexNode, RexNode> fieldConverter = field -> rexBuilder.makeInputRef(input, groupSetOffset + sumCallRef);
                    Function literalConverter = (Function)literalConverterProvider.apply(input);
                    List<RexNode> operands = List.of(PPLAggregateConvertRule.convertToNewOperand(rexCall.getOperands().getFirst(), fieldConverter, literalConverter), PPLAggregateConvertRule.convertToNewOperand(rexCall.getOperands().getLast(), fieldConverter, literalConverter));
                    return rexBuilder.makeCall(aggCall2.getType(), rexCall.getOperator(), operands);
                }, aggCall2.getName());
                continue;
            }
            int callRef = this.putToDistinctAggregateCalls(distinctAggregateCalls, aggCall2);
            newExprOnAggCall.add(input -> rexBuilder.makeInputRef(input, groupSetOffset + callRef), aggCall2.getName());
        }
        ImmutableBitSet newGroupSet = aggregate.getGroupSet();
        ImmutableList<ImmutableBitSet> newGroupSets = aggregate.getGroupSets();
        Set<Integer> fieldsUsed = RelOptUtil.getAllFields2(aggregate.getGroupSet(), distinctAggregateCalls);
        if (fieldsUsed.size() < newChildProjects.size()) {
            HashMap<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<Integer, Integer>();
            for (int source2 : fieldsUsed) {
                sourceFieldToTargetFieldMap.put(source2, sourceFieldToTargetFieldMap.size());
            }
            newGroupSet = aggregate.getGroupSet().permute(sourceFieldToTargetFieldMap);
            newGroupSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.getGroupSets(), sourceFieldToTargetFieldMap));
            Mappings.TargetMapping targetMapping = Mappings.target(sourceFieldToTargetFieldMap, newChildProjects.size(), fieldsUsed.size());
            ArrayList<AggregateCall> oldAggregateCalls = new ArrayList<AggregateCall>(distinctAggregateCalls);
            distinctAggregateCalls.clear();
            for (AggregateCall aggregateCall : oldAggregateCalls) {
                distinctAggregateCalls.add(aggregateCall.transform(targetMapping));
            }
            relBuilder.project(relBuilder.fields(fieldsUsed.stream().toList()));
        }
        relBuilder.aggregate(relBuilder.groupKey(newGroupSet, (Iterable<? extends ImmutableBitSet>)newGroupSets), (List<AggregateCall>)distinctAggregateCalls);
        ArrayList<RexNode> parentProjects = new ArrayList<RexNode>(relBuilder.fields(IntStream.range(0, groupSetOffset).boxed().toList()));
        parentProjects.addAll(newExprOnAggCall.transform((constructor, name) -> this.aliasMaybe(relBuilder, constructor.apply(relBuilder.peek()), (String)name)));
        relBuilder.project(parentProjects);
        call.transformTo(relBuilder.build());
    }

    private int putToDistinctAggregateCalls(List<AggregateCall> distinctAggregateCalls, AggregateCall aggCall) {
        int i = distinctAggregateCalls.indexOf(aggCall);
        if (i < 0) {
            i = distinctAggregateCalls.size();
            distinctAggregateCalls.add(aggCall);
        }
        return i;
    }

    private boolean isConvertableAggCall(AggregateCall aggCall, Project project) {
        return aggCall.getAggregation().getKind() == SqlKind.SUM && Config.isCallWithLiteral(project.getProjects().get(aggCall.getArgList().getFirst()));
    }

    private static Pair<RexInputRef, RexLiteral> getFieldAndLiteral(RexNode node) {
        RexCall call = (RexCall)node;
        RexNode arg1 = call.getOperands().getFirst();
        RexNode arg2 = call.getOperands().getLast();
        return arg1.getKind() == SqlKind.INPUT_REF ? Pair.of((RexInputRef)arg1, (RexLiteral)arg2) : Pair.of((RexInputRef)arg2, (RexLiteral)arg1);
    }

    private static RexNode convertToNewOperand(RexNode operand, Function<RexNode, RexNode> fieldConverter, Function<RexNode, RexNode> literalConverter) {
        if (operand.getKind() == SqlKind.INPUT_REF) {
            return fieldConverter.apply(operand);
        }
        return literalConverter.apply(operand);
    }

    private RexNode aliasMaybe(RelBuilder builder, RexNode node, String alias) {
        return alias == null ? node : builder.alias(node, alias);
    }

    static interface OperatorConstructor {
        public RexNode apply(RelNode var1);
    }

    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config SUM_CONVERTER = ImmutablePPLAggregateConvertRule.Config.builder().build().withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(Config::containsSumAggCall).oneInput(b1 -> b1.operand(LogicalProject.class).predicate(Config::containsCallWithNumber).anyInputs()));
        public static final List<SqlKind> CONVERTABLE_FUNCTIONS = List.of(SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES);

        public static boolean containsSumAggCall(LogicalAggregate aggregate) {
            return aggregate.getAggCallList().stream().anyMatch(aggCall -> aggCall.getAggregation().getKind() == SqlKind.SUM);
        }

        public static boolean containsCallWithNumber(LogicalProject project) {
            return project.getProjects().stream().anyMatch(Config::isCallWithLiteral);
        }

        private static boolean isCallWithLiteral(RexNode node) {
            if (CONVERTABLE_FUNCTIONS.contains((Object)node.getKind()) && node instanceof RexCall) {
                RexCall call = (RexCall)node;
                RexNode arg1 = call.getOperands().getFirst();
                RexNode arg2 = call.getOperands().getLast();
                return arg1.getKind() == SqlKind.INPUT_REF && arg2.getKind() == SqlKind.LITERAL || arg1.getKind() == SqlKind.LITERAL && arg2.getKind() == SqlKind.INPUT_REF;
            }
            return false;
        }

        @Override
        default public PPLAggregateConvertRule toRule() {
            return new PPLAggregateConvertRule(this);
        }
    }
}

