/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.ImmutableAggregateCaseToFilterRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import shaded.com.google.common.collect.ImmutableList;

/**
 * Planner rule that matches an {@link Aggregate} whose input is a {@link Project}
 * and rewrites aggregate calls whose sole argument is a three-operand
 * {@code CASE} expression into aggregate calls with a {@code FILTER}.
 *
 * <p>Rewrites performed (tried in this order):
 *
 * <ul>
 *   <li>{@code COUNT(DISTINCT CASE WHEN c THEN y END)}
 *       &rarr; {@code COUNT(DISTINCT y) FILTER (WHERE c IS TRUE)}
 *   <li>{@code COUNT(CASE WHEN c THEN <non-null literal> END)}
 *       &rarr; {@code COUNT(*) FILTER (WHERE c IS TRUE)}
 *   <li>{@code SUM0(CASE WHEN c THEN 1 ELSE 0 END)}
 *       &rarr; {@code COUNT(*) FILTER (WHERE c IS TRUE)} (typed BIGINT NOT NULL)
 *   <li>{@code AGG(CASE WHEN c THEN x END)}, when {@code AGG} allows a filter,
 *       &rarr; {@code AGG(x) FILTER (WHERE c IS TRUE)}
 *   <li>{@code SUM0(CASE WHEN c THEN x ELSE 0 END)}
 *       &rarr; {@code SUM0(x) FILTER (WHERE c IS TRUE)}
 * </ul>
 *
 * <p>If the {@code THEN} operand is a {@code NULL} literal and the {@code ELSE}
 * operand is not, the two are swapped and the condition becomes
 * {@code c IS NOT TRUE}. A {@code FILTER} already present on the call is ANDed
 * with the condition derived from the {@code CASE}. Calls with
 * {@code WITHIN DISTINCT} keys are not rewritten.
 */
@Value.Enclosing
public class AggregateCaseToFilterRule
extends RelRule<AggregateCaseToFilterRule.Config>
implements TransformationRule {

    /** Creates an AggregateCaseToFilterRule from a configuration. */
    protected AggregateCaseToFilterRule(Config config) {
        super(config);
    }

    /**
     * Creates an AggregateCaseToFilterRule with the given builder factory and
     * description, based on {@link Config#DEFAULT}.
     *
     * @deprecated use {@link Config#toRule()} instead
     */
    @Deprecated
    protected AggregateCaseToFilterRule(RelBuilderFactory relBuilderFactory, String description) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).withDescription(description).as(Config.class));
    }

    /**
     * Returns whether at least one aggregate call is a candidate for rewriting.
     * A candidate takes exactly one argument, that argument is a
     * three-operand CASE in the input project, and the call has no
     * WITHIN DISTINCT keys.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);
        final Project project = call.rel(1);
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            final int singleArg = soleArgument(aggregateCall);
            if (singleArg >= 0
                && aggregateCall.distinctKeys == null
                && isThreeArgCase(project.getProjects().get(singleArg))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rewrites every convertible aggregate call, builds a new project plus
     * aggregate, casts the result back to the original row type, and prunes
     * the original aggregate. Does nothing if no call could be rewritten.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);
        final Project project = call.rel(1);
        final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size());
        // Starts as the original project list; transform() appends any new
        // argument and filter expressions at the end, so existing field
        // indices stay valid.
        final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            final AggregateCall newCall = transform(aggregateCall, project, newProjects);
            newCalls.add(newCall == null ? aggregateCall : newCall);
        }
        if (newCalls.equals(aggregate.getAggCallList())) {
            // Nothing was rewritten.
            return;
        }
        final RelBuilder relBuilder = call.builder().push(project.getInput()).project(newProjects);
        final RelBuilder.GroupKey groupKey =
            relBuilder.groupKey(aggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>) aggregate.getGroupSets());
        // Rewritten calls may have different types (e.g. SUM0 -> COUNT yields
        // BIGINT NOT NULL), so convert back to the original row type.
        relBuilder.aggregate(groupKey, newCalls).convert(aggregate.getRowType(), false);
        call.transformTo(relBuilder.build());
        call.getPlanner().prune(aggregate);
    }

    /**
     * Attempts to rewrite a single aggregate call.
     *
     * @param call        aggregate call to rewrite
     * @param project     input project of the aggregate
     * @param newProjects project list being built; any new argument/filter
     *                    expressions needed by the rewritten call are appended
     * @return the rewritten call, or {@code null} if the call is not convertible
     *         (in which case {@code newProjects} is not modified)
     */
    private static @Nullable AggregateCall transform(AggregateCall call, Project project, List<RexNode> newProjects) {
        final int singleArg = soleArgument(call);
        if (singleArg < 0) {
            return null;
        }
        // Every rewrite below builds a call without WITHIN DISTINCT keys, which
        // would change the result, so leave such calls alone.
        if (call.distinctKeys != null) {
            return null;
        }
        final RexNode rexNode = project.getProjects().get(singleArg);
        if (!isThreeArgCase(rexNode)) {
            return null;
        }
        final RelOptCluster cluster = project.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        final RexCall caseCall = (RexCall) rexNode;

        // If THEN is NULL and ELSE is not, swap them and negate the condition.
        final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1))
            && !RexLiteral.isNullLiteral(caseCall.operands.get(2));
        final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1);
        final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2);

        final SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE;
        final RexNode filterFromCase = rexBuilder.makeCall(op, caseCall.operands.get(0));
        // Combine with a pre-existing FILTER on the call, if any.
        final RexNode filter = call.filterArg >= 0
            ? rexBuilder.makeCall(SqlStdOperatorTable.AND, project.getProjects().get(call.filterArg), filterFromCase)
            : filterFromCase;

        final SqlKind kind = call.getAggregation().getKind();
        final SqlParserPos pos = call.getParserPosition();

        if (call.isDistinct()) {
            // Only COUNT(DISTINCT CASE WHEN c THEN y END) is supported.
            if (kind == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
                newProjects.add(arg1);
                newProjects.add(filter);
                // Keep the approximate flag so APPROX_COUNT_DISTINCT stays approximate.
                return AggregateCall.create(pos, SqlStdOperatorTable.COUNT, true, call.isApproximate(), false,
                    call.rexList, ImmutableList.of(newProjects.size() - 2), newProjects.size() - 1,
                    null, RelCollations.EMPTY, call.getType(), call.getName());
            }
            return null;
        }

        // COUNT(CASE WHEN c THEN <non-null literal> END) -> COUNT(*) FILTER c
        if (kind == SqlKind.COUNT && arg1.isA(SqlKind.LITERAL)
            && !RexLiteral.isNullLiteral(arg1) && RexLiteral.isNullLiteral(arg2)) {
            newProjects.add(filter);
            return AggregateCall.create(pos, SqlStdOperatorTable.COUNT, false, false, false,
                call.rexList, ImmutableList.of(), newProjects.size() - 1,
                null, RelCollations.EMPTY, call.getType(), call.getName());
        }

        // SUM0(CASE WHEN c THEN 1 ELSE 0 END) -> COUNT(*) FILTER c
        if (kind == SqlKind.SUM0 && isIntLiteral(arg1, BigDecimal.ONE) && isIntLiteral(arg2, BigDecimal.ZERO)) {
            newProjects.add(filter);
            final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
            final RelDataType dataType =
                typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
            return AggregateCall.create(pos, SqlStdOperatorTable.COUNT, false, false, false,
                call.rexList, ImmutableList.of(), newProjects.size() - 1,
                null, RelCollations.EMPTY, dataType, call.getName());
        }

        // AGG(CASE WHEN c THEN x END) -> AGG(x) FILTER c
        // SUM0(CASE WHEN c THEN x ELSE 0 END) -> SUM0(x) FILTER c
        if ((RexLiteral.isNullLiteral(arg2) && call.getAggregation().allowsFilter())
            || (kind == SqlKind.SUM0 && isIntLiteral(arg2, BigDecimal.ZERO))) {
            newProjects.add(arg1);
            newProjects.add(filter);
            // The aggregate function is retained, so keep its modifiers:
            // dropping IGNORE NULLS or the WITHIN GROUP collation would change
            // the result. Collation indices still resolve because
            // newProjects keeps the original projects as its prefix.
            return AggregateCall.create(pos, call.getAggregation(), false, call.isApproximate(), call.ignoreNulls(),
                call.rexList, ImmutableList.of(newProjects.size() - 2), newProjects.size() - 1,
                null, call.collation, call.getType(), call.getName());
        }
        return null;
    }

    /** Returns the index of the call's only argument, or -1 if it does not have exactly one. */
    private static int soleArgument(AggregateCall aggregateCall) {
        final List<Integer> args = aggregateCall.getArgList();
        return args.size() == 1 ? args.get(0) : -1;
    }

    /** Returns whether the expression is a CASE with exactly three operands (WHEN, THEN, ELSE). */
    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).operands.size() == 3;
    }

    /** Returns whether the expression is an integer-typed literal equal to {@code value}. */
    private static boolean isIntLiteral(RexNode rexNode, BigDecimal value) {
        return rexNode instanceof RexLiteral
            && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
            && value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
    }

    /** Rule configuration. */
    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        /** Matches an Aggregate whose single input is a Project. */
        public static final Config DEFAULT = ImmutableAggregateCaseToFilterRule.Config.of().withOperandSupplier(b0 -> b0.operand(Aggregate.class).oneInput(b1 -> b1.operand(Project.class).anyInputs()));

        @Override
        default public AggregateCaseToFilterRule toRule() {
            return new AggregateCaseToFilterRule(this);
        }
    }
}

