/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.stm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.antlr.v4.runtime.atn.ATN;
import org.antlr.v4.runtime.atn.ATNState;
import org.antlr.v4.runtime.atn.AtomTransition;
import org.antlr.v4.runtime.atn.RangeTransition;
import org.antlr.v4.runtime.atn.SetTransition;
import org.antlr.v4.runtime.atn.Transition;
import org.antlr.v4.runtime.misc.IntSet;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.misc.IntervalSet;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.model.impl.sql.BasicSQLDialect;
import org.jkiss.dbeaver.model.lsm.sql.impl.syntax.SQLStandardParser;
import org.jkiss.dbeaver.model.stm.STMTreeNode;
import org.jkiss.dbeaver.model.stm.STMTreeRuleNode;
import org.jkiss.dbeaver.model.stm.STMTreeTermNode;
import org.jkiss.dbeaver.utils.ListNode;
import org.jkiss.utils.Pair;

public class LSMInspections {
    /** Reserved words of the basic SQL dialect; only token names found here are reported as predicted keywords. */
    @NotNull
    private static final Set<String> knownReservedWords = new HashSet<>(BasicSQLDialect.INSTANCE.getReservedWords());

    /**
     * Parser rule indices whose reachability is recorded during the ATN walk.
     * Indices 43, 85 and 21 feed the table-reference, column-reference and identifier flags of the result.
     */
    @NotNull
    private static final Set<Integer> reachabilityTestRules = Set.of(43, 85, 21, 34);

    /**
     * Parser rule indices the ATN walk does not descend into: every reachability-tested rule
     * plus a fixed list of additional rule indices.
     */
    @NotNull
    private static final Set<Integer> knownReservedWordsExcludeRules = Stream.concat(
        reachabilityTestRules.stream(),
        Stream.of(271, 269, 268, 261, 93, 270)
    ).collect(Collectors.toUnmodifiableSet());

    /**
     * Scans the direct children of {@code node}, in order, for the one relevant to {@code position}.
     * <p>
     * The scan stops at the first child that starts at or after {@code position}, or whose start
     * does not advance past the start of the previously accepted child.
     *
     * @param node     parent whose children are scanned
     * @param position text position to locate
     * @return {@code (child, true)} for the first child whose real interval contains {@code position};
     *         otherwise {@code (lastAcceptedChild, false)}, where the first element is {@code null}
     *         when no child was accepted before the scan stopped
     */
    @NotNull
    private static Pair<STMTreeNode, Boolean> findChildBeforeOrAtPosition(@NotNull STMTreeNode node, int position) {
        STMTreeNode nodeBefore = null;
        Interval nodeBeforeRange = null;
        for (int i = 0; i < node.getChildCount(); i++) {
            STMTreeNode cn = node.getStmChild(i);
            Interval range = cn.getRealInterval();
            if (range.a <= position && range.b >= position) {
                return Pair.of(cn, true);
            }
            boolean startsAtOrAfterPosition = range.a >= position;
            boolean doesNotAdvance = nodeBeforeRange != null && nodeBeforeRange.a >= range.a;
            if (startsAtOrAfterPosition || doesNotAdvance) {
                break;
            }
            nodeBefore = cn;
            nodeBeforeRange = range;
        }
        return Pair.of(nodeBefore, false);
    }

    /**
     * Runs the syntax inspection from the start state of parser rule 0 with an empty rule stack,
     * i.e. without any parse tree context.
     *
     * @return the inspection result for that start state
     */
    @NotNull
    public static SyntaxInspectionResult prepareOffquerySyntaxInspection() {
        ATN atn = SQLStandardParser._ATN;
        int startStateNumber = atn.ruleToStartState[0].stateNumber;
        ListNode<Integer> emptyStack = ListNode.of(null);
        return LSMInspections.inspectAbstractSyntaxAtState(emptyStack, atn.states.get(startStateNumber));
    }

    /*
     * Handled impossible loop by adding 'first' condition
     * WARNING - void declaration
     * Enabled aggressive block sorting
     */
    @Nullable
    public static SyntaxInspectionResult prepareAbstractSyntaxInspection(@NotNull STMTreeNode root, int position) {
        ATNState initialState;
        STMTreeNode subroot = root;
        ATN atn = SQLStandardParser._ATN;
        Interval range = subroot.getRealInterval();
        if (position < range.a) {
            return LSMInspections.prepareOffquerySyntaxInspection();
        }
        Pair<STMTreeNode, Boolean> p = Pair.of((Object)subroot, (Object)true);
        boolean bl = true;
        do {
            if (!bl || (bl = false) || !true) {
                subroot = (STMTreeNode)p.getFirst();
                p = LSMInspections.findChildBeforeOrAtPosition(subroot, position);
            }
            Object object = p.getFirst();
            if (!(object instanceof STMTreeTermNode)) continue;
            STMTreeTermNode sTMTreeTermNode = (STMTreeTermNode)object;
            STMTreeTermNode cfr_ignored_1 = (STMTreeTermNode)object;
            break;
        } while (p.getFirst() != null);
        if (p.getSecond() == null) {
            initialState = (ATNState)atn.states.get(subroot.getAtnState());
            return LSMInspections.inspectAbstractSyntaxAtTreeState(subroot, initialState);
        }
        if (((Boolean)p.getSecond()).booleanValue()) {
            initialState = (ATNState)atn.states.get(((STMTreeNode)p.getFirst()).getAtnState());
            return LSMInspections.inspectAbstractSyntaxAtTreeState(subroot, initialState);
        }
        STMTreeNode node = (STMTreeNode)p.getFirst();
        STMTreeNode sTMTreeNode = node;
        if (sTMTreeNode instanceof STMTreeTermNode) {
            void tn;
            STMTreeTermNode sTMTreeTermNode = (STMTreeTermNode)sTMTreeNode;
            STMTreeTermNode cfr_ignored_2 = (STMTreeTermNode)sTMTreeNode;
            initialState = ((ATNState)atn.states.get((int)tn.getAtnState())).getTransitions()[0].target;
            return LSMInspections.inspectAbstractSyntaxAtTreeState(subroot, initialState);
        }
        STMTreeNode sTMTreeNode2 = node;
        if (sTMTreeNode2 instanceof STMTreeRuleNode) {
            void rn;
            STMTreeRuleNode sTMTreeRuleNode = (STMTreeRuleNode)sTMTreeNode2;
            STMTreeRuleNode cfr_ignored_3 = (STMTreeRuleNode)sTMTreeNode2;
            initialState = atn.ruleToStopState[rn.getRuleContext().getRuleIndex()];
            return LSMInspections.inspectAbstractSyntaxAtTreeState(subroot, initialState);
        }
        STMTreeTermNode tn = LSMInspections.findLastTerm(root);
        if (tn == null) throw new IllegalStateException("TODO/WTF");
        subroot = tn;
        initialState = ((ATNState)atn.states.get((int)tn.getAtnState())).getTransitions()[0].target;
        return LSMInspections.inspectAbstractSyntaxAtTreeState(subroot, initialState);
    }

    /*
     * WARNING - void declaration
     */
    @Nullable
    private static STMTreeTermNode findLastTerm(@NotNull STMTreeNode root) {
        ListNode stack = ListNode.of((Object)root);
        while (ListNode.hasAny((ListNode)stack)) {
            STMTreeNode node = (STMTreeNode)stack.data;
            stack = stack.next;
            STMTreeNode sTMTreeNode = node;
            if (sTMTreeNode instanceof STMTreeTermNode) {
                void term;
                STMTreeTermNode sTMTreeTermNode = (STMTreeTermNode)sTMTreeNode;
                STMTreeTermNode cfr_ignored_0 = (STMTreeTermNode)sTMTreeNode;
                return term;
            }
            int i = 0;
            while (i < node.getChildCount()) {
                stack = ListNode.push((ListNode)stack, (Object)node.getStmChild(i));
                ++i;
            }
        }
        return null;
    }

    /*
     * WARNING - void declaration
     */
    @NotNull
    public static List<STMTreeTermNode> prepareTerms(@NotNull STMTreeNode root) {
        ArrayList<STMTreeTermNode> terms = new ArrayList<STMTreeTermNode>();
        ListNode stack = ListNode.of((Object)root);
        while (ListNode.hasAny((ListNode)stack)) {
            STMTreeNode node = (STMTreeNode)stack.data;
            stack = stack.next;
            STMTreeNode sTMTreeNode = node;
            if (sTMTreeNode instanceof STMTreeTermNode) {
                void term;
                STMTreeTermNode cfr_ignored_0 = (STMTreeTermNode)sTMTreeNode;
                STMTreeTermNode cfr_ignored_1 = (STMTreeTermNode)sTMTreeNode;
                terms.add((STMTreeTermNode)term);
                continue;
            }
            int i = node.getChildCount() - 1;
            while (i >= 0) {
                stack = ListNode.push((ListNode)stack, (Object)node.getStmChild(i));
                --i;
            }
        }
        return terms;
    }

    /*
     * WARNING - void declaration
     * Enabled aggressive block sorting
     */
    @Nullable
    private static SyntaxInspectionResult inspectAbstractSyntaxAtTreeState(@NotNull STMTreeNode node, @NotNull ATNState initialState) {
        void var6_6;
        Iterator iterator;
        STMTreeNode sTMTreeNode;
        ListNode stack = ListNode.of(null);
        LinkedList<void> path = new LinkedList<void>();
        STMTreeNode sTMTreeNode2 = sTMTreeNode = node instanceof TerminalNode ? node.getStmParent() : node;
        while ((iterator = var6_6) instanceof RuleNode) {
            void rn;
            RuleNode cfr_ignored_0 = (RuleNode)iterator;
            RuleNode cfr_ignored_1 = (RuleNode)iterator;
            path.addFirst(rn);
            STMTreeNode sTMTreeNode3 = var6_6.getStmParent();
        }
        for (RuleNode ruleNode : path) {
            stack = ListNode.push((ListNode)stack, (Object)ruleNode.getRuleContext().getRuleIndex());
        }
        int atnStateIndex = node.getAtnState();
        if (atnStateIndex < 0) {
            return null;
        }
        return LSMInspections.inspectAbstractSyntaxAtState((ListNode<Integer>)stack, initialState);
    }

    /**
     * Collects the transitions reachable from {@code initialState} and turns them into an
     * inspection result.
     * <p>
     * A token is reported only when its vocabulary display name is one of the known reserved words.
     * The three boolean flags of the result are the reachability of rules 43, 85 and 21.
     *
     * @param stack        stack of enclosing rule indices
     * @param initialState ATN state to start from
     * @return predicted tokens, predicted words and rule reachability
     */
    @NotNull
    private static SyntaxInspectionResult inspectAbstractSyntaxAtState(@NotNull ListNode<Integer> stack, @NotNull ATNState initialState) {
        // Every tested rule starts out as unreachable; the walk flips entries to true.
        Map<Integer, Boolean> reachabilityTests = new HashMap<>(reachabilityTestRules.size());
        for (Integer ruleIndex : reachabilityTestRules) {
            reachabilityTests.put(ruleIndex, false);
        }

        Collection<Transition> followingTerms = LSMInspections.collectFollowingTerms(
            stack, initialState, knownReservedWordsExcludeRules, reachabilityTests
        );

        Set<String> predictedWords = new HashSet<>();
        Set<Integer> predictedTokenIds = new HashSet<>();
        for (Interval interval : LSMInspections.getTransitionTokens(followingTerms).getIntervals()) {
            int last = interval.b;
            for (int tokenId = interval.a; tokenId <= last; tokenId++) {
                String word = SQLStandardParser.VOCABULARY.getDisplayName(tokenId);
                if (word == null || !knownReservedWords.contains(word)) {
                    continue;
                }
                predictedTokenIds.add(tokenId);
                predictedWords.add(word);
            }
        }

        return new SyntaxInspectionResult(
            predictedTokenIds,
            predictedWords,
            reachabilityTests,
            reachabilityTests.get(43),
            reachabilityTests.get(85),
            reachabilityTests.get(21)
        );
    }

    /**
     * Merges the token types matched by the given transitions into one interval set.
     * <p>
     * Atom, range and set transitions contribute their label(s); not-set and wildcard transitions
     * are accepted but contribute nothing.
     *
     * @param transitions transitions to inspect
     * @return the union of the token types of all atom, range and set transitions
     * @throws UnsupportedOperationException if a transition has any other serialization type
     */
    @NotNull
    private static IntervalSet getTransitionTokens(@NotNull Collection<Transition> transitions) {
        IntervalSet tokens = new IntervalSet();
        for (Transition transition : transitions) {
            switch (transition.getSerializationType()) {
                case Transition.ATOM: {
                    tokens.add(((AtomTransition) transition).label);
                    break;
                }
                case Transition.RANGE: {
                    RangeTransition range = (RangeTransition) transition;
                    tokens.add(range.from, range.to);
                    break;
                }
                case Transition.SET: {
                    tokens.addAll(((SetTransition) transition).set);
                    break;
                }
                case Transition.NOT_SET:
                case Transition.WILDCARD: {
                    // Deliberately contributes no tokens.
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unrecognized ATN transition type.");
                }
            }
        }
        return tokens;
    }

    /**
     * Walks the ATN from {@code initialState} and collects the token-matching transitions
     * (atom, range, set, not-set, wildcard) reachable through the other transition types.
     * <p>
     * Each pending state carries its own stack of rule indices:
     * <ul>
     *   <li>leaving a rule-stop state is allowed only towards the rule recorded below the top of the
     *       stack, and pops the stack;</li>
     *   <li>leaving a rule-start state marks that rule as reached in {@code reachabilityTest} (if it is
     *       a key there), is skipped for rules in {@code exceptRules}, and otherwise pushes the rule;</li>
     *   <li>any other state keeps the stack unchanged.</li>
     * </ul>
     * A target state is enqueued at most once, regardless of the stack it was reached with.
     *
     * @param stateStack       rule index stack for the initial state
     * @param initialState     state to start from
     * @param exceptRules      rule indices that must not be entered
     * @param reachabilityTest map updated in place: present keys are set to {@code true} when the
     *                         rule's start state is processed
     * @return the collected token-matching transitions
     * @throws UnsupportedOperationException if a transition has an unknown serialization type
     */
    @NotNull
    private static Collection<Transition> collectFollowingTerms(@NotNull ListNode<Integer> stateStack, @NotNull ATNState initialState, Set<Integer> exceptRules, @NotNull Map<Integer, Boolean> reachabilityTest) {
        Set<ATNState> visited = new HashSet<>();
        Set<Transition> results = new HashSet<>();
        // Used as a LIFO work list: entries are appended and removed at the tail.
        LinkedList<Pair<ATNState, ListNode<Integer>>> q = new LinkedList<>();
        q.addLast(Pair.of(initialState, stateStack));

        while (!q.isEmpty()) {
            Pair<ATNState, ListNode<Integer>> pair = q.removeLast();
            ATNState state = pair.getFirst();
            ListNode<Integer> stack = pair.getSecond();

            for (Transition transition : state.getTransitions()) {
                switch (transition.getSerializationType()) {
                    case Transition.RANGE:
                    case Transition.ATOM:
                    case Transition.SET:
                    case Transition.NOT_SET:
                    case Transition.WILDCARD: {
                        results.add(transition);
                        break;
                    }
                    case Transition.EPSILON:
                    case Transition.RULE:
                    case Transition.PREDICATE:
                    case Transition.ACTION:
                    case Transition.PRECEDENCE: {
                        ListNode<Integer> transitionStack;
                        switch (state.getStateType()) {
                            case ATNState.RULE_STOP: {
                                boolean hasCaller = stack != null && stack.data != null
                                    && stack.next != null && stack.next.data != null;
                                // Return only into the rule recorded as the caller.
                                if (!hasCaller || transition.target.ruleIndex != stack.next.data) {
                                    continue;
                                }
                                transitionStack = stack.next;
                                break;
                            }
                            case ATNState.RULE_START: {
                                reachabilityTest.computeIfPresent(state.ruleIndex, (k, v) -> true);
                                if (exceptRules.contains(state.ruleIndex)) {
                                    continue;
                                }
                                transitionStack = ListNode.push(stack, state.ruleIndex);
                                break;
                            }
                            default: {
                                transitionStack = stack;
                                break;
                            }
                        }
                        if (visited.add(transition.target)) {
                            q.addLast(Pair.of(transition.target, transitionStack));
                        }
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException("Unrecognized ATN transition type.");
                    }
                }
            }
        }
        return results;
    }

    /**
     * Outcome of a syntax inspection: the predicted tokens and words, the per-rule reachability
     * map and three boolean expectation flags, all stored exactly as passed to the constructor.
     */
    public static class SyntaxInspectionResult {
        /** Token type ids of the predicted tokens. */
        @NotNull
        public final Set<Integer> predictedTokensIds;
        /** Words of the predicted tokens. */
        @NotNull
        public final Set<String> predictedWords;
        /** Reachability flag per parser rule index. */
        @NotNull
        private final Map<Integer, Boolean> reachabilityTests;
        /** Flag supplied by the creator of this result. */
        public final boolean expectingTableReference;
        /** Flag supplied by the creator of this result. */
        public final boolean expectingColumnReference;
        /** Flag supplied by the creator of this result. */
        public final boolean expectingIdentifier;

        /**
         * Stores the given values without copying them.
         *
         * @param predictedTokenIds        predicted token type ids
         * @param predictedWords           predicted words
         * @param reachabilityTests        reachability flag per parser rule index
         * @param expectingTableReference  value of {@link #expectingTableReference}
         * @param expectingColumnReference value of {@link #expectingColumnReference}
         * @param expectingIdentifier      value of {@link #expectingIdentifier}
         */
        public SyntaxInspectionResult(
            @NotNull Set<Integer> predictedTokenIds,
            @NotNull Set<String> predictedWords,
            @NotNull Map<Integer, Boolean> reachabilityTests,
            boolean expectingTableReference,
            boolean expectingColumnReference,
            boolean expectingIdentifier
        ) {
            this.expectingIdentifier = expectingIdentifier;
            this.expectingColumnReference = expectingColumnReference;
            this.expectingTableReference = expectingTableReference;
            this.reachabilityTests = reachabilityTests;
            this.predictedWords = predictedWords;
            this.predictedTokensIds = predictedTokenIds;
        }

        /**
         * Re-keys the reachability map by rule name, looked up in {@code SQLStandardParser.ruleNames}.
         *
         * @return a new map from rule name to reachability flag
         */
        @NotNull
        public Map<String, Boolean> getReachabilityByName() {
            return this.reachabilityTests.entrySet().stream().collect(Collectors.toMap(
                entry -> SQLStandardParser.ruleNames[entry.getKey()],
                Map.Entry::getValue
            ));
        }
    }
}

