/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.rule.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.opensearch.planner.rules.ImmutableSortExprIndexScanRule;
import org.opensearch.sql.opensearch.planner.rules.InterruptibleRelRule;
import org.opensearch.sql.opensearch.planner.rules.LimitIndexScanRule;
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;
import org.opensearch.sql.opensearch.storage.scan.context.SortExprDigest;
import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil;

/**
 * Planner rule matching {@code Sort -> Project -> AbstractCalciteIndexScan} (see {@link
 * Config#DEFAULT}) that pushes sort keys referencing projected expressions into the index scan.
 *
 * <p>When every sort key is order-equivalent to a plain scan field (as reported by {@link
 * OpenSearchRelOptUtil#getOrderEquivalentInputInfo}), the rule does nothing. Otherwise it pushes
 * the sort digests into the scan. If the Sort also carries a fetch/offset, that is pushed too. The
 * Sort is then replaced by a copy of the Project (with the Sort's trait set) on top of the new
 * scan. If the Sort's fetch/offset cannot be absorbed by the scan, the plan is left unchanged so
 * that the limit is never dropped.
 */
@Value.Enclosing
public class SortExprIndexScanRule extends InterruptibleRelRule<SortExprIndexScanRule.Config> {
  protected SortExprIndexScanRule(Config config) {
    super(config);
  }

  @Override
  protected void onMatchImpl(RelOptRuleCall call) {
    final Sort sort = call.rel(0);
    final Project project = call.rel(1);
    final AbstractCalciteIndexScan scan = call.rel(2);

    // Only rewrite when the three matched nodes share one convention.
    if (sort.getConvention() != project.getConvention()
        || project.getConvention() != scan.getConvention()) {
      return;
    }
    if (!PlanUtils.sortReferencesExpr(sort, project)) {
      return;
    }

    // For each sorted project column, record whether it is order-equivalent to a scan field.
    Map<Integer, Optional<Pair<Integer, Boolean>>> orderEquivInfoMap = new HashMap<>();
    boolean allSimpleExprs = true;
    for (RelFieldCollation fieldCollation : sort.getCollation().getFieldCollations()) {
      int projectIndex = fieldCollation.getFieldIndex();
      Optional<Pair<Integer, Boolean>> orderEquivInfo =
          OpenSearchRelOptUtil.getOrderEquivalentInputInfo(project.getProjects().get(projectIndex));
      orderEquivInfoMap.put(projectIndex, orderEquivInfo);
      if (orderEquivInfo.isEmpty()) {
        allSimpleExprs = false;
      }
    }
    // Every key reduces to a plain field: there is no expression for this rule to push.
    if (allSimpleExprs) {
      return;
    }

    boolean scanProvidesRequiredCollation =
        OpenSearchRelOptUtil.canScanProvideSortCollation(
            scan, project, sort.collation, orderEquivInfoMap);
    // A top-K already pushed into the scan cannot be re-sorted by a different collation.
    if (scan.isTopKPushed() && !scanProvidesRequiredCollation) {
      return;
    }

    List<SortExprDigest> sortExprDigests =
        extractSortExpressionInfos(sort, project, scan, orderEquivInfoMap);
    if (sortExprDigests.isEmpty() || !canPushDownSortExpressionInfos(sortExprDigests)) {
      return;
    }

    // If the scan's pushed top-K already yields the required order, reuse it as-is.
    AbstractCalciteIndexScan newScan =
        scan.isTopKPushed() && scanProvidesRequiredCollation
            ? scan.copy()
            : scan.pushdownSortExpr(sortExprDigests);

    // The Sort is replaced by a Project below, so any fetch/offset it carries must be absorbed
    // by the scan. If that is not possible, keep the original plan instead of dropping the limit.
    if (sort.fetch != null || sort.offset != null) {
      Integer limitValue = LimitIndexScanRule.extractLimitValue(sort.fetch);
      Integer offsetValue = LimitIndexScanRule.extractOffsetValue(sort.offset);
      if (!(newScan instanceof CalciteLogicalIndexScan)
          || !(sort instanceof LogicalSort)
          || limitValue == null
          || offsetValue == null) {
        return;
      }
      newScan =
          (CalciteLogicalIndexScan)
              ((CalciteLogicalIndexScan) newScan)
                  .pushDownLimit((LogicalSort) sort, limitValue, offsetValue);
    }

    if (newScan == null) {
      return;
    }
    Project newProject =
        project.copy(sort.getTraitSet(), newScan, project.getProjects(), project.getRowType());
    call.transformTo(newProject);
    PlanUtils.tryPruneRelNodes(call);
  }

  /**
   * Builds one {@link SortExprDigest} per sort key, pairing the i-th sort expression with the
   * i-th field collation of {@code sort}.
   *
   * @return the digests in sort-key order; keys that map to {@code null} are skipped
   */
  private List<SortExprDigest> extractSortExpressionInfos(
      Sort sort,
      Project project,
      AbstractCalciteIndexScan scan,
      Map<Integer, Optional<Pair<Integer, Boolean>>> orderEquivInfoMap) {
    List<RexNode> sortKeys = sort.getSortExps();
    List<RelFieldCollation> collations = sort.getCollation().getFieldCollations();
    List<SortExprDigest> sortExprDigests = new ArrayList<>(sortKeys.size());
    for (int i = 0; i < sortKeys.size(); i++) {
      SortExprDigest digest =
          mapThroughProject(sortKeys.get(i), project, scan, collations.get(i), orderEquivInfoMap);
      if (digest != null) {
        sortExprDigests.add(digest);
      }
    }
    return sortExprDigests;
  }

  /**
   * Resolves one sort key through the project.
   *
   * <p>If the referenced project column is order-equivalent to a scan field, the digest names
   * that scan field. The direction is reversed when the equivalence flag is {@code true}, and the
   * null direction is kept. Otherwise the digest carries the project expression itself.
   */
  private SortExprDigest mapThroughProject(
      RexNode sortKey,
      Project project,
      AbstractCalciteIndexScan scan,
      RelFieldCollation collation,
      Map<Integer, Optional<Pair<Integer, Boolean>>> orderEquivInfoMap) {
    assert (sortKey instanceof RexInputRef) : "sort key should be always RexInputRef";
    RexInputRef inputRef = (RexInputRef) sortKey;
    Optional<Pair<Integer, Boolean>> orderEquivalentInfo =
        orderEquivInfoMap.getOrDefault(collation.getFieldIndex(), Optional.empty());
    if (orderEquivalentInfo.isPresent()) {
      Pair<Integer, Boolean> equivalence = orderEquivalentInfo.get();
      RelFieldCollation.Direction equivalentDirection =
          Boolean.TRUE.equals(equivalence.getRight())
              ? collation.getDirection().reverse()
              : collation.getDirection();
      List<String> scanFieldNames = scan.getRowType().getFieldNames();
      String fieldName = scanFieldNames.get(equivalence.getLeft());
      return new SortExprDigest(fieldName, equivalentDirection, collation.nullDirection);
    }
    RexNode projectExpression = project.getProjects().get(inputRef.getIndex());
    return new SortExprDigest(projectExpression, collation.getDirection(), collation.nullDirection);
  }

  /**
   * Returns whether every digest can be pushed into the scan.
   *
   * <p>Simple field references are always accepted. Expression keys are rejected when they are
   * missing, literal or constant, or when their SQL type is not a supported script-sort type.
   */
  private boolean canPushDownSortExpressionInfos(List<SortExprDigest> sortExprDigests) {
    for (SortExprDigest digest : sortExprDigests) {
      RexNode expr = digest.getExpression();
      if (expr == null && StringUtils.isEmpty(digest.getFieldName())) {
        return false;
      }
      if (digest.isSimpleFieldReference()) {
        continue;
      }
      // A non-simple digest must carry an expression; guard before dereferencing it.
      if (expr == null
          || expr instanceof RexLiteral
          || RexUtil.isConstant(expr)
          || !isSupportedSortScriptType(expr.getType().getSqlTypeName())) {
        return false;
      }
    }
    return true;
  }

  /** Character, approximate-numeric and integer SQL types are accepted for script-based sort. */
  private boolean isSupportedSortScriptType(SqlTypeName sqlTypeName) {
    return SqlTypeName.CHAR_TYPES.contains(sqlTypeName)
        || SqlTypeName.APPROX_TYPES.contains(sqlTypeName)
        || SqlTypeName.INT_TYPES.contains(sqlTypeName);
  }

  /**
   * Rule configuration.
   *
   * <p>{@link #DEFAULT} matches a Sort with a non-empty collation over a Project without
   * windowed aggregates ({@code containsOver}), over a scan with no aggregate pushed down.
   */
  @Value.Immutable
  public interface Config extends OpenSearchRuleConfig {
    Config DEFAULT =
        ImmutableSortExprIndexScanRule.Config.builder()
            .build()
            .withOperandSupplier(
                b0 ->
                    b0.operand(Sort.class)
                        .predicate(sort -> !sort.collation.getFieldCollations().isEmpty())
                        .oneInput(
                            b1 ->
                                b1.operand(Project.class)
                                    .predicate(Predicate.not(Project::containsOver))
                                    .oneInput(
                                        b2 ->
                                            b2.operand(AbstractCalciteIndexScan.class)
                                                .predicate(
                                                    AbstractCalciteIndexScan::noAggregatePushed)
                                                .noInputs())));

    default SortExprIndexScanRule toRule() {
      return new SortExprIndexScanRule(this);
    }
  }
}

