/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.ai.gpt3;

import com.theokanning.openai.OpenAiService;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.DBPDataSourceContainer;
import org.jkiss.dbeaver.model.DBUtils;
import org.jkiss.dbeaver.model.ai.AIEngineSettings;
import org.jkiss.dbeaver.model.ai.AISettings;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionEngine;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionRequest;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionResponse;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionScope;
import org.jkiss.dbeaver.model.ai.gpt3.GPTModel;
import org.jkiss.dbeaver.model.exec.DBCExecutionContext;
import org.jkiss.dbeaver.model.exec.DBCExecutionContextDefaults;
import org.jkiss.dbeaver.model.logical.DBSLogicalDataSource;
import org.jkiss.dbeaver.model.navigator.DBNUtils;
import org.jkiss.dbeaver.model.preferences.DBPPreferenceStore;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSEntity;
import org.jkiss.dbeaver.model.struct.DBSEntityAttribute;
import org.jkiss.dbeaver.model.struct.DBSObject;
import org.jkiss.dbeaver.model.struct.DBSObjectContainer;
import org.jkiss.dbeaver.model.struct.rdb.DBSSchema;
import org.jkiss.dbeaver.model.struct.rdb.DBSTablePartition;
import org.jkiss.dbeaver.runtime.DBWorkbench;
import org.jkiss.dbeaver.utils.RuntimeUtils;
import org.jkiss.utils.CommonUtils;
import retrofit2.HttpException;

/**
 * Completion engine that asks the OpenAI completions API (through {@link OpenAiService}) to turn a
 * natural-language prompt into an SQL {@code SELECT} query.
 *
 * <p>The prompt sent to the model consists of:
 * <ul>
 *     <li>a header naming the SQL dialect;</li>
 *     <li>one {@code # table(col,...)} line per table in the requested scope;</li>
 *     <li>an optional "current schema" hint;</li>
 *     <li>the user's text;</li>
 *     <li>a trailing {@code SELECT}, so the model continues the query.</li>
 * </ul>
 *
 * <p>Subclasses can customize the prompt and the generated query through the protected
 * {@link #addPromptAttributes}, {@link #addPromptExtra}, {@link #postProcessPrompt} and
 * {@link #postProcessGeneratedQuery} hooks.
 *
 * <p>NOTE(review): API clients are cached in a static, unsynchronized {@link HashMap} keyed by the
 * data source container id. Confirm that this engine is never invoked concurrently, or switch to a
 * concurrent map.
 */
public class GPTCompletionEngine implements DAICompletionEngine {
    private static final Log log = Log.getLog(GPTCompletionEngine.class);
    /** Total number of attempts (first call plus retries) when the API answers HTTP 429. */
    private static final int MAX_REQUEST_ATTEMPTS = 3;
    /** Pause between attempts after an HTTP 429 answer, in milliseconds. */
    private static final int RETRY_PAUSE_MS = 1000;
    /** HTTP status code returned by the API when the client is rate limited. */
    private static final int HTTP_TOO_MANY_REQUESTS = 429;
    /** API clients, cached per data source container id; cleared by {@link #resetServices()}. */
    private static final Map<String, OpenAiService> clientInstances = new HashMap<>();
    /** Value sent as {@code maxTokens} in every completion request. */
    private static final int GPT_MODEL_MAX_TOKENS = 2048;
    /** Upper bound, in characters, used to size the metadata part of the prompt. */
    private static final int MAX_PROMPT_LENGTH = 7500;
    private static final boolean SUPPORTS_ATTRS = true;

    @Override
    public String getEngineName() {
        return "GPT-3";
    }

    /**
     * Returns the model id configured in the "gpt.model" preference (may be empty when unset).
     */
    @Override
    public String getModelName() {
        return getPreferenceStore().getString("gpt.model");
    }

    /**
     * Generates a single SQL completion for the given request.
     *
     * <p>{@code returnOnlyCompletion} and {@code maxResults} are currently ignored: exactly one
     * response is always returned. Its completion text may be {@code null} or empty when the
     * request was cancelled or the model produced nothing.
     *
     * @throws DBException if the API token is missing or the request cannot be built
     */
    @Override
    @NotNull
    public List<DAICompletionResponse> performQueryCompletion(
        @NotNull DBRProgressMonitor monitor,
        @Nullable DBSLogicalDataSource dataSource,
        @NotNull DBCExecutionContext executionContext,
        @NotNull DAICompletionRequest completionRequest,
        boolean returnOnlyCompletion,
        int maxResults
    ) throws DBException {
        String result = requestCompletion(completionRequest, monitor, executionContext);
        DAICompletionResponse response = createCompletionResponse(dataSource, executionContext, result);
        return Collections.singletonList(response);
    }

    /**
     * Returns {@code true} when a non-empty API token is available (see {@link #acquireToken()}).
     */
    @Override
    public boolean isValidConfiguration() {
        return !CommonUtils.isEmpty(acquireToken());
    }

    /**
     * Wraps the generated query text into a response object; subclasses may enrich it.
     */
    @NotNull
    protected DAICompletionResponse createCompletionResponse(
        DBSLogicalDataSource dataSource,
        DBCExecutionContext executionContext,
        String result
    ) {
        DAICompletionResponse response = new DAICompletionResponse();
        response.setResultCompletion(result);
        return response;
    }

    /**
     * Creates a new API client with a 30 second timeout.
     *
     * @throws DBException if no API token is configured
     */
    private static OpenAiService initGPTApiClientInstance() throws DBException {
        String token = acquireToken();
        if (CommonUtils.isEmpty(token)) {
            throw new DBException("Empty API token value");
        }
        return new OpenAiService(token, Duration.ofSeconds(30L));
    }

    /**
     * Reads the API token from the "openai" engine configuration ("gpt.token" property), falling
     * back to the "gpt.token" workbench preference when the property is absent.
     */
    private static String acquireToken() {
        AIEngineSettings openAiConfig = AISettings.getSettings().getEngineConfiguration("openai");
        Object token = openAiConfig.getProperties().get("gpt.token");
        if (token != null) {
            return token.toString();
        }
        return getPreferenceStore().getString("gpt.token");
    }

    /**
     * Builds the prompt, sends it to the API and post-processes the answer into a query.
     *
     * @return the generated query, optionally preceded by the prompt lines as {@code --} comments.
     *     Returns {@code ""} if cancelled while the metadata was being collected. Returns
     *     {@code null} if cancelled just before the API call, or if the model returned no text.
     */
    private String requestCompletion(
        @NotNull DAICompletionRequest request,
        @NotNull DBRProgressMonitor monitor,
        @NotNull DBCExecutionContext executionContext
    ) throws DBException {
        DBSObjectContainer mainObject = resolveMainObject(request.getScope(), executionContext);
        OpenAiService service = getOrCreateService(executionContext.getDataSource().getContainer());

        String modifiedRequest = addDBMetadataToRequest(monitor, request, executionContext, mainObject);
        if (monitor.isCanceled()) {
            return "";
        }
        CompletionRequest completionRequest = createCompletionRequest(modifiedRequest);
        monitor.subTask("Request GPT completion");
        try {
            if (getPreferenceStore().getBoolean("gpt.log.query")) {
                log.debug("GPT request:\n" + completionRequest.getPrompt());
            }
            if (monitor.isCanceled()) {
                return null;
            }
            List<CompletionChoice> choices = requestChoices(service, completionRequest);
            if (choices == null || choices.isEmpty()) {
                // No answer at all is treated the same way as an empty answer.
                return null;
            }
            String completionText = choices.get(0).getText();
            if (CommonUtils.isEmpty(completionText)) {
                return null;
            }
            // The prompt ends with "SELECT" and ";" is a stop sequence, so restore both.
            completionText = "SELECT " + completionText.trim() + ";";
            completionText = postProcessGeneratedQuery(monitor, mainObject, executionContext, completionText);
            if (getPreferenceStore().getBoolean("ai.completion.includeSourceTextInQuery")) {
                completionText = formatPromptAsComments(request.getPromptText()) + completionText;
            }
            return completionText.trim();
        } finally {
            monitor.done();
        }
    }

    /**
     * Picks the object whose tables describe the request scope.
     *
     * <p>For CURRENT_SCHEMA this is the default schema, falling back to the default catalog. For
     * CURRENT_DATABASE it is the default catalog. Otherwise (or when nothing is found) it is the
     * data source itself.
     */
    private static DBSObjectContainer resolveMainObject(
        DAICompletionScope scope,
        DBCExecutionContext executionContext
    ) {
        DBSObjectContainer mainObject = null;
        var contextDefaults = executionContext.getContextDefaults();
        if (contextDefaults != null) {
            switch (scope) {
                case CURRENT_SCHEMA:
                    mainObject = contextDefaults.getDefaultSchema();
                    if (mainObject == null) {
                        mainObject = contextDefaults.getDefaultCatalog();
                    }
                    break;
                case CURRENT_DATABASE:
                    mainObject = contextDefaults.getDefaultCatalog();
                    break;
                default:
                    break;
            }
        }
        if (mainObject == null) {
            mainObject = (DBSObjectContainer) executionContext.getDataSource();
        }
        return mainObject;
    }

    /**
     * Returns the cached API client for the container, creating and caching one if needed.
     *
     * @throws DBException if a client must be created and no API token is configured
     */
    private static OpenAiService getOrCreateService(DBPDataSourceContainer container) throws DBException {
        String containerId = container.getId();
        OpenAiService service = clientInstances.get(containerId);
        if (service == null) {
            service = initGPTApiClientInstance();
            clientInstances.put(containerId, service);
        }
        return service;
    }

    /**
     * Sends the completion request, making at most {@link #MAX_REQUEST_ATTEMPTS} attempts.
     *
     * <p>Only rate-limit errors (HTTP 429) are retried, after a {@link #RETRY_PAUSE_MS} pause, and
     * there is no pause after the final attempt. Any other failure propagates immediately.
     */
    private static List<CompletionChoice> requestChoices(
        OpenAiService service,
        CompletionRequest completionRequest
    ) {
        for (int attempt = 1; ; attempt++) {
            try {
                return service.createCompletion(completionRequest).getChoices();
            } catch (HttpException e) {
                if (e.code() != HTTP_TOO_MANY_REQUESTS || attempt >= MAX_REQUEST_ATTEMPTS) {
                    throw e;
                }
                log.debug("AI service failed. Retry (" + e.getMessage() + ")");
                RuntimeUtils.pause(RETRY_PAUSE_MS);
            }
        }
    }

    /**
     * Renders each non-empty prompt line as a trimmed {@code -- } SQL comment line.
     *
     * <p>The prompt's original line order is preserved.
     */
    private static String formatPromptAsComments(String promptText) {
        StringBuilder comments = new StringBuilder();
        for (String line : promptText.split("\n")) {
            if (!CommonUtils.isEmpty(line)) {
                comments.append("-- ").append(line.trim()).append("\n");
            }
        }
        return comments.toString();
    }

    private static DBPPreferenceStore getPreferenceStore() {
        return DBWorkbench.getPlatform().getPreferenceStore();
    }

    /**
     * Builds the API request for the given prompt.
     *
     * <p>The temperature and model come from preferences. When no model is configured, the id of
     * {@link GPTModel#TEXT_DAVINCI02} is used instead of sending an empty id. A non-empty id is
     * passed through unchanged, even when it is not a known {@link GPTModel}.
     */
    private static CompletionRequest createCompletionRequest(@NotNull String request) throws DBException {
        double temperature = getPreferenceStore().getDouble("gpt.model.temperature");
        String modelId = getPreferenceStore().getString("gpt.model");
        if (CommonUtils.isEmpty(modelId)) {
            modelId = GPTModel.TEXT_DAVINCI02.getName();
        }
        return CompletionRequest.builder()
            .prompt(request)
            .temperature(temperature)
            .maxTokens(GPT_MODEL_MAX_TOKENS)
            .frequencyPenalty(0.0)
            .presencePenalty(0.0)
            // "#" ends the metadata comment section and ";" ends the statement.
            .stop(List.of("#", ";"))
            .model(modelId)
            .build();
    }

    /**
     * Drops all cached API clients, e.g. after the token or model settings change.
     */
    public static void resetServices() {
        clientInstances.clear();
    }

    /**
     * Builds the full prompt text from the table metadata, the current schema hint and the user
     * prompt. The result ends with {@code SELECT}.
     *
     * <p>For the CUSTOM scope only the request's custom entities are described; otherwise
     * {@code mainObject} is described recursively.
     *
     * @throws DBException if {@code mainObject} or its data source is missing, or the prompt is blank
     */
    protected String addDBMetadataToRequest(
        DBRProgressMonitor monitor,
        DAICompletionRequest request,
        DBCExecutionContext executionContext,
        DBSObjectContainer mainObject
    ) throws DBException {
        if (mainObject == null
            || mainObject.getDataSource() == null
            || CommonUtils.isEmptyTrimmed(request.getPromptText())) {
            throw new DBException("Invalid completion request");
        }
        StringBuilder additionalMetadata = new StringBuilder();
        additionalMetadata.append("### ")
            .append(mainObject.getDataSource().getSQLDialect().getDialectName())
            .append(" SQL tables, with their properties:\n#\n");

        String tail = "";
        if (executionContext != null && executionContext.getContextDefaults() != null) {
            DBSSchema defaultSchema = executionContext.getContextDefaults().getDefaultSchema();
            if (defaultSchema != null) {
                tail = "#\n# Current schema is " + defaultSchema.getName() + "\n";
            }
        }

        // Character budget left for table descriptions, minus a fixed margin of 20 characters.
        int maxRequestLength = MAX_PROMPT_LENGTH - additionalMetadata.length() - tail.length() - 20;
        if (request.getScope() != DAICompletionScope.CUSTOM) {
            additionalMetadata.append(generateObjectDescription(monitor, request, mainObject, maxRequestLength));
        } else {
            for (DBSEntity entity : request.getCustomEntities()) {
                additionalMetadata.append(generateObjectDescription(monitor, request, entity, maxRequestLength));
            }
        }

        String promptText = postProcessPrompt(monitor, mainObject, executionContext, request.getPromptText().trim());
        additionalMetadata.append(tail).append("#\n###").append(promptText).append("\nSELECT");
        return additionalMetadata.toString();
    }

    /**
     * Describes an object for the prompt.
     *
     * <p>An entity becomes one {@code # name(attr,...);} line. A container is described by
     * recursing into its non-system, non-hidden, non-partition children. A child is added only
     * while the accumulated description stays within {@code maxRequestLength}; the first child
     * that would overflow stops the listing. Objects without a navigator node yield {@code ""}.
     */
    private String generateObjectDescription(
        @NotNull DBRProgressMonitor monitor,
        @NotNull DAICompletionRequest request,
        @NotNull DBSObject object,
        int maxRequestLength
    ) throws DBException {
        if (DBNUtils.getNodeByObject(monitor, object, false) == null) {
            return "";
        }
        StringBuilder description = new StringBuilder();
        if (object instanceof DBSEntity) {
            DBSEntity entity = (DBSEntity) object;
            description.append("# ").append(DBUtils.getQuotedIdentifier(object));
            description.append("(");
            boolean firstAttr = addPromptAttributes(monitor, entity, description, true);
            addPromptExtra(monitor, entity, description, firstAttr);
            description.append(");\n");
        } else if (object instanceof DBSObjectContainer) {
            DBSObjectContainer container = (DBSObjectContainer) object;
            monitor.subTask("Load cache of " + object.getName());
            // NOTE(review): structure flags value 3 — confirm against DBSObjectContainer.STRUCT_* constants.
            container.cacheStructure(monitor, 3);
            for (DBSObject child : container.getChildren(monitor)) {
                if (DBUtils.isSystemObject(child) || DBUtils.isHiddenObject(child) || child instanceof DBSTablePartition) {
                    continue;
                }
                String childText = generateObjectDescription(monitor, request, child, maxRequestLength);
                if (description.length() + childText.length() > maxRequestLength) {
                    log.debug("Trim GPT metadata prompt  at table '" + child.getName() + "' - too long request");
                    break;
                }
                description.append(childText);
            }
        }
        return description.toString();
    }

    /**
     * Appends the entity's visible attribute names to {@code prompt}, comma-separated.
     *
     * @param firstAttr {@code true} if nothing has been appended inside the parentheses yet
     * @return {@code true} if still no attribute has been appended (i.e. the next one needs no comma)
     */
    protected boolean addPromptAttributes(
        DBRProgressMonitor monitor,
        DBSEntity entity,
        StringBuilder prompt,
        boolean firstAttr
    ) throws DBException {
        List<? extends DBSEntityAttribute> attributes = entity.getAttributes(monitor);
        if (attributes != null) {
            for (DBSEntityAttribute attribute : attributes) {
                if (DBUtils.isHiddenObject(attribute)) {
                    continue;
                }
                if (!firstAttr) {
                    prompt.append(",");
                }
                firstAttr = false;
                prompt.append(attribute.getName());
            }
        }
        return firstAttr;
    }

    /**
     * Hook for subclasses to append extra entity details after the attribute list; no-op here.
     */
    protected void addPromptExtra(
        DBRProgressMonitor monitor,
        DBSEntity object,
        StringBuilder description,
        boolean firstAttr
    ) throws DBException {
    }

    /**
     * Hook for subclasses to rewrite the (trimmed) user prompt; returns it unchanged here.
     */
    protected String postProcessPrompt(
        DBRProgressMonitor monitor,
        DBSObjectContainer mainObject,
        DBCExecutionContext executionContext,
        String promptText
    ) {
        return promptText;
    }

    /**
     * Hook for subclasses to rewrite the generated {@code SELECT ...;} query; returns it unchanged here.
     */
    protected String postProcessGeneratedQuery(
        DBRProgressMonitor monitor,
        DBSObjectContainer mainObject,
        DBCExecutionContext executionContext,
        String completionText
    ) {
        return completionText;
    }
}

