/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.impl;

import com.dbeaver.model.ai.AIAssistantPro;
import com.dbeaver.model.ai.AIChatResponseConsumer;
import com.dbeaver.model.ai.audio.AIAudioStream;
import com.dbeaver.model.ai.audio.AIEngineAudio;
import com.dbeaver.model.ai.audio.AITranscriptResult;
import com.dbeaver.model.ai.impl.AIEngineRequestFactoryPro;
import com.dbeaver.model.ai.rag.AIEmbeddingGenerator;
import com.dbeaver.model.ai.rag.AIScopeSelectorFactory;
import com.dbeaver.model.qm.QMService;
import com.dbeaver.model.qm.QMServiceProvider;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.ai.AIAssistantResponse;
import org.jkiss.dbeaver.model.ai.AIFunctionContext;
import org.jkiss.dbeaver.model.ai.AIFunctionResult;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIPromptGenerator;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.ai.engine.AIEngine;
import org.jkiss.dbeaver.model.ai.engine.AIEngineRequest;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponse;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseChunk;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseConsumer;
import org.jkiss.dbeaver.model.ai.engine.AIFunctionCall;
import org.jkiss.dbeaver.model.ai.impl.AIAssistantImpl;
import org.jkiss.dbeaver.model.ai.impl.AIDatabaseSnapshotService;
import org.jkiss.dbeaver.model.ai.impl.AIEngineRequestFactory;
import org.jkiss.dbeaver.model.ai.impl.DummyTokenCounter;
import org.jkiss.dbeaver.model.ai.impl.TokenCounter;
import org.jkiss.dbeaver.model.ai.registry.AIEngineDescriptor;
import org.jkiss.dbeaver.model.ai.registry.AISettingsManager;
import org.jkiss.dbeaver.model.app.DBPWorkspace;
import org.jkiss.dbeaver.model.logical.DBSLogicalDataSource;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.runtime.DBWorkbench;
import org.jkiss.utils.CommonUtils;

/**
 * Pro edition of the AI assistant.
 * <p>
 * On top of {@link AIAssistantImpl} it:
 * <ul>
 *     <li>optionally narrows the database context with a RAG scope selector (setting {@code ai.rag.enabled});</li>
 *     <li>records text completions in the query manager smart-completion history when the workspace is a
 *     {@link QMServiceProvider};</li>
 *     <li>streams chat responses while executing AI function calls requested by the engine;</li>
 *     <li>exposes speech transcription for engines implementing {@link AIEngineAudio}.</li>
 * </ul>
 */
public class AIAssistantProImpl
extends AIAssistantImpl
implements AIAssistantPro {
    private static final Log log = Log.getLog(AIAssistantProImpl.class);

    /** Maximum number of AI function calls executed for one chat request (function calls may re-issue requests). */
    private static final int MAX_FUNCTION_CALLS = 5;

    private final AIScopeSelectorFactory scopeSelectorFactory;

    public AIAssistantProImpl(@NotNull DBPWorkspace workspace) {
        super(workspace);
        this.scopeSelectorFactory = new AIScopeSelectorFactory(workspace);
    }

    /**
     * Creates the request factory used by this assistant: an {@link AIEngineRequestFactoryPro} backed by an
     * {@link AIDatabaseSnapshotService} and a {@link DummyTokenCounter}.
     */
    protected AIEngineRequestFactory createRequestFactory() {
        return new AIEngineRequestFactoryPro(new AIDatabaseSnapshotService(), new DummyTokenCounter());
    }

    /**
     * Generates a text completion, narrowing {@code context} to the relevant scope first when RAG is enabled.
     * If the response is text, a context was given and the workspace has an active query manager session,
     * the completion is also saved to the smart-completion history.
     *
     * @param monitor         progress monitor
     * @param context         database context, or {@code null} if there is none
     * @param systemGenerator system prompt generator
     * @param messages        conversation messages
     * @return the assistant response
     * @throws DBException if the completion request fails
     */
    @Override
    @NotNull
    public AIAssistantResponse generateText(@NotNull DBRProgressMonitor monitor, @Nullable AIDatabaseContext context, @NotNull AIPromptGenerator systemGenerator, @NotNull List<AIMessage> messages) throws DBException {
        AIAssistantResponse result = super.generateText(
            monitor, this.selectRelevantContext(monitor, context, messages), systemGenerator, messages);
        if (result.isText() && context != null && this.workspace instanceof QMServiceProvider qsp) {
            this.saveCompletionHistory(qsp, context, result.getText());
        }
        return result;
    }

    /**
     * Saves a text completion to the query manager smart-completion history, if a QM service and session exist.
     */
    private void saveCompletionHistory(@NotNull QMServiceProvider qsp, @NotNull AIDatabaseContext context, String text) {
        QMService qmService = qsp.getActiveQueryManagerService();
        String qmSid = qsp.getQueryManagerSessionId();
        if (qmService == null || qmSid == null) {
            return;
        }
        DBSLogicalDataSource dataSource = context.getDataSource();
        // NOTE(review): the completion text is passed for both text arguments — confirm this is intended.
        qmService.saveSmartCompletionHistory(
            qmSid,
            dataSource.getDataSourceContainer().getId(),
            dataSource.getName(),
            dataSource.getCurrentCatalog(),
            dataSource.getCurrentSchema(),
            text,
            text);
    }

    /**
     * Requests a chat completion and forwards the (possibly streamed) response to {@code chatListener}.
     *
     * @throws DBException if AI is not enabled or the request cannot be built or sent; any other exception is
     *                     wrapped into a {@link DBException}. The engine is closed before the exception propagates.
     */
    @Override
    public void generateTextStream(@NotNull DBRProgressMonitor monitor, @Nullable AIDatabaseContext context, @NotNull AIPromptGenerator systemGenerator, @NotNull List<AIMessage> messages, @NotNull AIChatResponseConsumer chatListener) throws DBException {
        this.checkAiEnablement();
        AIEngine<?> engine = null;
        try {
            AIEngineDescriptor engineDescriptor = this.getEngineDescriptor();
            engine = engineDescriptor.createEngineInstance();
            AIFunctionContext functionContext = new AIFunctionContext(
                monitor, this.selectRelevantContext(monitor, context, messages), systemGenerator, messages);
            this.executeEngineStreamRequest(
                monitor, functionContext, systemGenerator, messages, chatListener, engine, engineDescriptor, new AtomicInteger(0));
        } catch (DBException e) {
            // On a synchronous failure the response listener never closes the engine, so release it here.
            closeEngineQuietly(engine);
            throw e;
        } catch (Exception e) {
            closeEngineQuietly(engine);
            throw new DBException("Error requesting completion stream", e);
        }
    }

    /**
     * Builds an engine request for {@code messages}, sends it and forwards the response to {@code chatListener}.
     * <p>
     * Function calls returned by the engine are executed. An {@code INFORMATION} result is appended as a user
     * message and the request is re-issued recursively; any other result is emitted as the final message chunk.
     * {@code callCounter} is shared across recursive calls and limits the total to {@link #MAX_FUNCTION_CALLS}.
     * Streaming is used when the {@code ai.useStreamMode} preference is set. Otherwise a single completion is
     * requested and replayed to the listener.
     */
    private void executeEngineStreamRequest(final @NotNull DBRProgressMonitor monitor, final @NotNull AIFunctionContext functionContext, final @NotNull AIPromptGenerator systemGenerator, final @NotNull List<AIMessage> messages, final @NotNull AIChatResponseConsumer chatListener, final @NotNull AIEngine<?> engine, final @NotNull AIEngineDescriptor engineDescriptor, final @NotNull AtomicInteger callCounter) throws DBException {
        AIEngineRequest request = this.requestFactory.build(monitor, engine, engineDescriptor, systemGenerator, functionContext.getContext(), messages);
        final boolean loggingEnabled = this.isLoggingEnabled();
        if (loggingEnabled) {
            log.debug("AI chat request:\n" + CommonUtils.addTextIndent(request.getMessages().toString(), "\t"));
        }
        AIEngineResponseConsumer listener = new AIEngineResponseConsumer() {
            private boolean closed = false;

            @Override
            public void nextChunk(@NotNull AIEngineResponseChunk chunk) {
                AIFunctionCall functionCall = chunk.getFunctionCall();
                if (functionCall != null) {
                    this.handleFunctionCall(functionCall);
                    return;
                }
                List<?> choices = chunk.getChoices();
                if (!CommonUtils.isEmpty(choices)) {
                    if (loggingEnabled) {
                        log.debug("AI chat response chunk: " + choices);
                    }
                    chatListener.nextMessageChunk((String) choices.getFirst());
                }
            }

            /** Executes a function call requested by the engine and forwards its result to the chat listener. */
            private void handleFunctionCall(@NotNull AIFunctionCall functionCall) {
                try {
                    if (callCounter.incrementAndGet() > MAX_FUNCTION_CALLS) {
                        chatListener.error(new DBException("Too many AI function calls (" + MAX_FUNCTION_CALLS + ")"));
                        chatListener.close();
                        return;
                    }
                    functionContext.addFunctionCall(functionCall);
                    AIFunctionResult result = AIAssistantProImpl.this.callFunction(functionContext, functionCall);
                    chatListener.nextFunctionCall(functionCall, result);
                    String value = CommonUtils.toString(result.getValue());
                    if (result.getType() == AIFunctionResult.FunctionType.INFORMATION) {
                        // Informational results are fed back to the engine as a user message so it can continue.
                        List<AIMessage> newMessages = new ArrayList<>(messages);
                        newMessages.add(AIMessage.userMessage(value));
                        AIAssistantProImpl.this.executeEngineStreamRequest(
                            monitor, functionContext, systemGenerator, newMessages, chatListener, engine, engineDescriptor, callCounter);
                    } else {
                        chatListener.nextMessageChunk(value);
                        chatListener.close();
                    }
                } catch (Exception e) {
                    chatListener.error(e);
                }
            }

            @Override
            public void error(@NotNull Throwable throwable) {
                if (loggingEnabled) {
                    log.debug("AI chat response error", throwable);
                }
                chatListener.error(throwable);
            }

            @Override
            public void close() {
                if (this.closed) {
                    return;
                }
                // Mark closed up front so a failing or re-entrant close cannot run the cleanup twice.
                this.closed = true;
                try {
                    chatListener.close();
                } finally {
                    closeEngineQuietly(engine);
                }
            }
        };
        // Single-element array so the retry lambda can clear the flag: the truncation warning is shown only once.
        boolean[] truncationWarningPending = {request.wasPromptTruncated()};
        AIAssistantProImpl.callWithRetry(() -> {
            if (truncationWarningPending[0]) {
                truncationWarningPending[0] = false;
                chatListener.warning("Context description was truncated. To reduce context size you [could specify custom database scope](https://dbeaver.com/docs/dbeaver/AI-chat/#defining-the-scope).");
            }
            if (AIAssistantProImpl.useStreamMode()) {
                engine.requestCompletionStream(monitor, request, listener);
            } else {
                AIEngineResponse response = engine.requestCompletion(monitor, request);
                if (response.getFunctionCall() != null) {
                    listener.nextChunk(new AIEngineResponseChunk(response.getFunctionCall()));
                } else if (response.getVariants() != null) {
                    listener.nextChunk(new AIEngineResponseChunk(response.getVariants()));
                } else {
                    listener.error(new DBException("Empty response"));
                }
                listener.close();
            }
            return null;
        });
    }

    /**
     * Closes {@code engine}, logging instead of propagating any failure. Does nothing for {@code null}.
     */
    private static void closeEngineQuietly(@Nullable AIEngine<?> engine) {
        if (engine == null) {
            return;
        }
        try {
            engine.close();
        } catch (Exception e) {
            log.error("Error closing AI engine", e);
        }
    }

    /**
     * @return whether the configured engine implements {@link AIEngineAudio}
     */
    @Override
    public boolean supportTranscription() throws DBException {
        return this.isEngineSupports(AIEngineAudio.class);
    }

    /**
     * Transcribes {@code audio} with a freshly created engine, which is closed afterwards.
     *
     * @throws DBException if the engine does not implement {@link AIEngineAudio} or the transcription fails
     */
    @Override
    @NotNull
    public AITranscriptResult createSpeechTranscription(@NotNull AIAudioStream audio) throws DBException {
        try (AIEngine<?> engine = this.createEngine()) {
            if (engine instanceof AIEngineAudio transcriber) {
                return transcriber.createSpeechTranscription(audio);
            }
        }
        throw new DBException("Transcription is not supported");
    }

    /**
     * Narrows {@code originalContext} to the objects relevant to {@code messages} using the RAG scope selector.
     * <p>
     * The original context is returned unchanged in these cases:
     * <ul>
     *     <li>RAG is disabled;</li>
     *     <li>no context was given;</li>
     *     <li>the engine is not an {@link AIEmbeddingGenerator};</li>
     *     <li>selection fails. The error is logged, not propagated.</li>
     * </ul>
     */
    @Nullable
    private AIDatabaseContext selectRelevantContext(@NotNull DBRProgressMonitor monitor, @Nullable AIDatabaseContext originalContext, @NotNull List<AIMessage> messages) {
        if (!AIAssistantProImpl.isRagEnabled() || originalContext == null) {
            return originalContext;
        }
        try (AIEngine<?> engine = this.createEngine()) {
            if (engine instanceof AIEmbeddingGenerator generator) {
                return this.scopeSelectorFactory.getScopeSelector().select(monitor, generator, originalContext, messages);
            }
        } catch (Exception e) {
            log.error("Error selecting relevant context", e);
        }
        return originalContext;
    }

    private static boolean useStreamMode() {
        return DBWorkbench.getPlatform().getPreferenceStore().getBoolean("ai.useStreamMode");
    }

    private static boolean isRagEnabled() {
        // toBoolean tolerates null and string-typed values, unlike a direct (Boolean) cast + unboxing.
        return CommonUtils.toBoolean(AISettingsManager.getInstance().getSettings().getProperty("ai.rag.enabled", Boolean.FALSE));
    }
}

