/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.ai.openai;

import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.DBPDataSourceContainer;
import org.jkiss.dbeaver.model.ai.AIEngineSettings;
import org.jkiss.dbeaver.model.ai.AISettingsRegistry;
import org.jkiss.dbeaver.model.ai.completion.AbstractAICompletionEngine;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionContext;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionMessage;
import org.jkiss.dbeaver.model.ai.format.IAIFormatter;
import org.jkiss.dbeaver.model.ai.metadata.MetadataProcessor;
import org.jkiss.dbeaver.model.ai.openai.GPTModel;
import org.jkiss.dbeaver.model.ai.openai.service.AdaptedOpenAiService;
import org.jkiss.dbeaver.model.ai.openai.service.GPTCompletionAdapter;
import org.jkiss.dbeaver.model.exec.DBCExecutionContext;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSObjectContainer;
import org.jkiss.dbeaver.runtime.DBWorkbench;
import org.jkiss.dbeaver.utils.RuntimeUtils;
import org.jkiss.utils.CommonUtils;
import retrofit2.HttpException;

/**
 * AI completion engine backed by the OpenAI completion and chat-completion APIs.
 *
 * <p>Service clients are cached by data source container id in a static map shared by all
 * instances of this engine; {@link #resetServices()} clears that cache.
 */
public class OpenAICompletionEngine
extends AbstractAICompletionEngine<GPTCompletionAdapter, Object> {
    private static final Log log = Log.getLog(OpenAICompletionEngine.class);

    /** Total number of attempts (initial call plus retries) made for one completion request. */
    protected static final int MAX_REQUEST_ATTEMPTS = 3;

    private static final String PROP_MODEL = "gpt.model";
    private static final String PROP_TOKEN = "gpt.token";
    private static final String PROP_LOG_QUERY = "gpt.log.query";
    private static final String PROP_TEMPERATURE = "gpt.model.temperature";

    /** Duration passed to {@link AdaptedOpenAiService} when a client is created (presumably its request timeout). */
    private static final Duration SERVICE_TIMEOUT = Duration.ofSeconds(30L);
    /** HTTP status code that triggers a delayed retry (Too Many Requests / rate limiting). */
    private static final int HTTP_TOO_MANY_REQUESTS = 429;
    /** Pause before retrying a rate-limited request, in milliseconds. */
    private static final int RATE_LIMIT_RETRY_DELAY_MS = 1000;
    /** Text identifying the service error raised when the prompt exceeds the model context window. */
    private static final String CONTEXT_LENGTH_ERROR_MARKER = "This model's maximum";

    private static final Map<String, GPTCompletionAdapter> clientInstances = new HashMap<>();
    /** Captures, in group 1, the context length reported in a "context length is N tokens" error text. */
    private static final Pattern sizeErrorPattern = Pattern.compile("context length is ([0-9]+) tokens");

    /**
     * Builds a legacy (non-chat) completion request whose prompt is the truncated messages
     * flattened into a single text ending with {@code "SELECT "}.
     */
    private static CompletionRequest buildSingleRequest(
        boolean chatMode,
        @NotNull List<DAICompletionMessage> messages,
        int maxTokens,
        Double temperature,
        String modelId
    ) {
        String prompt = buildSingleMessage(truncateMessages(chatMode, messages, maxTokens));
        return CompletionRequest.builder()
            .prompt(prompt)
            .temperature(temperature)
            .frequencyPenalty(0.0)
            .n(1)
            .presencePenalty(0.0)
            .stop(List.of("#", ";"))
            .model(modelId)
            .build();
    }

    /**
     * Builds a chat completion request, mapping each truncated message to a {@link ChatMessage}
     * that keeps the message role id and content.
     */
    private static ChatCompletionRequest buildChatRequest(
        boolean chatMode,
        @NotNull List<DAICompletionMessage> messages,
        int maxTokens,
        Double temperature,
        String modelId
    ) {
        List<ChatMessage> chatMessages = truncateMessages(chatMode, messages, maxTokens).stream()
            .map(m -> new ChatMessage(m.getRole().getId(), m.getContent()))
            .toList();
        return ChatCompletionRequest.builder()
            .messages(chatMessages)
            .temperature(temperature)
            .frequencyPenalty(0.0)
            .presencePenalty(0.0)
            .n(1)
            .model(modelId)
            .build();
    }

    /** Drops all cached service clients so they are re-created (e.g. with a new token) on next use. */
    public static void resetServices() {
        clientInstances.clear();
    }

    @Override
    public String getEngineName() {
        return "OpenAI GPT";
    }

    /**
     * Returns the configured model name, or the name of {@link GPTModel#GPT_TURBO16} when the
     * {@code gpt.model} property is not set.
     */
    public String getModelName() {
        return CommonUtils.toString(getSettings().getProperties().get(PROP_MODEL), GPTModel.GPT_TURBO16.getName());
    }

    /** The engine is usable only when an API token can be acquired. */
    @Override
    public boolean isValidConfiguration() {
        return !CommonUtils.isEmpty(acquireToken());
    }

    @Override
    public Map<String, GPTCompletionAdapter> getServiceMap() {
        return clientInstances;
    }

    /**
     * Prepends a metadata message describing the scope object to {@code messages}, sends the
     * completion request and post-processes the returned text.
     *
     * @return the processed completion, or {@code ""} if the monitor was canceled before sending
     * @throws DBException if metadata collection, client creation or the request fails
     */
    @Override
    @Nullable
    protected String requestCompletion(
        @NotNull DBRProgressMonitor monitor,
        @NotNull DAICompletionContext context,
        @NotNull List<DAICompletionMessage> messages,
        @NotNull IAIFormatter formatter,
        boolean chatCompletion
    ) throws DBException {
        DBCExecutionContext executionContext = context.getExecutionContext();
        DBSObjectContainer mainObject = getScopeObject(context, executionContext);
        // NOTE(review): the 2000-token reduction presumably leaves room for the user messages and
        // the response — confirm against MetadataProcessor.
        DAICompletionMessage metadataMessage = MetadataProcessor.INSTANCE.createMetadataMessage(
            monitor, context, mainObject, formatter, getInstructions(chatCompletion), getMaxTokens() - 2000);
        List<DAICompletionMessage> mergedMessages = new ArrayList<>(messages.size() + 1);
        mergedMessages.add(metadataMessage);
        mergedMessages.addAll(messages);
        if (monitor.isCanceled()) {
            return "";
        }
        GPTCompletionAdapter service = getServiceInstance(executionContext);
        Object completionRequest = createCompletionRequest(chatCompletion, mergedMessages);
        String completionText = callCompletion(monitor, chatCompletion, mergedMessages, service, completionRequest);
        return processCompletion(
            mergedMessages, monitor, executionContext, mainObject, completionText, formatter, getModel().isChatAPI());
    }

    /**
     * Token limit of the model actually used for requests. It resolves the model through
     * {@link #getModel()} so truncation matches the model sent to the service.
     */
    @Override
    protected int getMaxTokens() {
        return getModel().getMaxTokens();
    }

    /**
     * Creates a new OpenAI service client using the acquired API token.
     *
     * @throws DBException if no API token is configured
     */
    protected GPTCompletionAdapter initGPTApiClientInstance() throws DBException {
        String token = acquireToken();
        if (CommonUtils.isEmpty(token)) {
            throw new DBException("Empty API token value");
        }
        return new AdaptedOpenAiService(token, SERVICE_TIMEOUT);
    }

    /**
     * Returns the API token from the engine settings. It falls back to the {@code gpt.token}
     * workbench preference only when the settings have no such property.
     */
    protected String acquireToken() {
        AIEngineSettings config = getSettings();
        Object token = config.getProperties().get(PROP_TOKEN);
        if (token != null) {
            return token.toString();
        }
        return DBWorkbench.getPlatform().getPreferenceStore().getString(PROP_TOKEN);
    }

    @Override
    @NotNull
    protected AIEngineSettings getSettings() {
        return AISettingsRegistry.getInstance().getSettings().getEngineConfiguration("openai");
    }

    /**
     * Sends {@code completionRequest} and returns the text of the first returned choice.
     *
     * <p>Retry behavior, limited to {@link #MAX_REQUEST_ATTEMPTS} attempts in total:
     * <ul>
     *   <li>HTTP 429 responses are retried after a short pause.</li>
     *   <li>Context-length errors are retried with the messages re-truncated to the reported
     *       context length, when that length is below {@link #getMaxTokens()}.</li>
     * </ul>
     * Any other failure is rethrown unchanged. The monitor is always marked done on exit.
     *
     * @return the completion text, or {@code null} if the monitor was canceled before sending
     * @throws DBException if the service returned no completion choices
     */
    @Override
    @Nullable
    protected String callCompletion(
        @NotNull DBRProgressMonitor monitor,
        boolean chatMode,
        @NotNull List<DAICompletionMessage> messages,
        @NotNull GPTCompletionAdapter service,
        @NotNull Object completionRequest
    ) throws DBException {
        monitor.subTask("Request GPT completion");
        try {
            boolean logQueries = CommonUtils.toBoolean(getSettings().getProperties().get(PROP_LOG_QUERY));
            if (logQueries) {
                logRequest(completionRequest);
            }
            if (monitor.isCanceled()) {
                return null;
            }
            List<?> choices = requestChoicesWithRetry(chatMode, messages, service, completionRequest);
            if (choices == null || choices.isEmpty()) {
                throw new DBException("AI service returned no completion choices");
            }
            String completionText = extractText(choices.get(0));
            if (logQueries) {
                log.debug("GPT response:\n" + completionText);
            }
            return completionText;
        } finally {
            monitor.done();
        }
    }

    /** Writes the outgoing request (chat messages or single prompt) to the debug log. */
    private static void logRequest(@NotNull Object completionRequest) {
        if (completionRequest instanceof ChatCompletionRequest chatRequest) {
            log.debug("Chat GPT request:\n" + chatRequest.getMessages().stream()
                .map(message -> "# " + message.getRole() + "\n" + message.getContent())
                .collect(Collectors.joining("\n")));
        } else {
            log.debug("GPT request:\n" + ((CompletionRequest) completionRequest).getPrompt());
        }
    }

    /** Executes the request, retrying rate-limit and context-length failures up to MAX_REQUEST_ATTEMPTS times. */
    private List<?> requestChoicesWithRetry(
        boolean chatMode,
        @NotNull List<DAICompletionMessage> messages,
        @NotNull GPTCompletionAdapter service,
        @NotNull Object completionRequest
    ) {
        Object request = completionRequest;
        for (int attempt = 1; ; attempt++) {
            try {
                return getCompletionChoices(service, request);
            } catch (HttpException e) {
                // Only rate limiting is retried; check the budget first so the last failure isn't delayed.
                if (e.code() != HTTP_TOO_MANY_REQUESTS || attempt >= MAX_REQUEST_ATTEMPTS) {
                    throw e;
                }
                log.debug("AI service failed. Retry (" + e.getMessage() + ")");
                RuntimeUtils.pause(RATE_LIMIT_RETRY_DELAY_MS);
            } catch (OpenAiHttpException e) {
                if (!isContextLengthError(e) || attempt >= MAX_REQUEST_ATTEMPTS) {
                    throw e;
                }
                int contextLength = parseContextLength(e);
                // Re-truncating only helps if the reported limit is smaller than the one already used.
                if (contextLength <= 0 || contextLength >= getMaxTokens()) {
                    throw e;
                }
                log.debug("Prompt exceeds model context length (" + contextLength + " tokens). Retry with truncated prompt");
                request = createCompletionRequest(chatMode, messages, contextLength);
            }
        }
    }

    /** Tells whether the service error reports that the prompt exceeded the model's context length. */
    private static boolean isContextLengthError(@NotNull OpenAiHttpException e) {
        String message = e.getMessage();
        return message != null && message.contains(CONTEXT_LENGTH_ERROR_MARKER);
    }

    /** Returns the context length reported in the error message, or -1 when it cannot be found. */
    private static int parseContextLength(@NotNull OpenAiHttpException e) {
        String message = e.getMessage();
        if (message == null) {
            return -1;
        }
        Matcher matcher = sizeErrorPattern.matcher(message);
        return matcher.find() ? CommonUtils.toInt(matcher.group(1)) : -1;
    }

    /** Returns the text of a legacy {@link CompletionChoice} or the message content of a chat choice. */
    private static String extractText(Object choice) {
        if (choice instanceof CompletionChoice completionChoice) {
            return completionChoice.getText();
        }
        return ((ChatCompletionChoice) choice).getMessage().getContent();
    }

    /**
     * Returns the cached service client for the execution context's data source container,
     * creating and caching one on first use.
     *
     * @throws DBException if a new client cannot be created (e.g. no API token)
     */
    @Override
    protected GPTCompletionAdapter getServiceInstance(@NotNull DBCExecutionContext executionContext) throws DBException {
        DBPDataSourceContainer container = executionContext.getDataSource().getContainer();
        GPTCompletionAdapter service = clientInstances.get(container.getId());
        if (service == null) {
            service = initGPTApiClientInstance();
            clientInstances.put(container.getId(), service);
        }
        return service;
    }

    /** Creates a request truncated to the current model's token limit. */
    @Override
    protected Object createCompletionRequest(boolean chatMode, @NotNull List<DAICompletionMessage> messages) {
        return createCompletionRequest(chatMode, messages, getMaxTokens());
    }

    /**
     * Creates a chat or legacy completion request, depending on the configured model.
     * Messages are truncated to {@code maxTokens}; temperature defaults to 0.0.
     */
    @Override
    protected Object createCompletionRequest(boolean chatMode, @NotNull List<DAICompletionMessage> messages, int maxTokens) {
        Double temperature = CommonUtils.toDouble(getSettings().getProperties().get(PROP_TEMPERATURE), 0.0);
        GPTModel model = getModel();
        if (model.isChatAPI()) {
            return buildChatRequest(chatMode, messages, maxTokens, temperature, model.getName());
        }
        return buildSingleRequest(chatMode, messages, maxTokens, temperature, model.getName());
    }

    /** Resolves the configured model, defaulting to {@link GPTModel#GPT_TURBO16} when unset or empty. */
    @NotNull
    private GPTModel getModel() {
        String modelId = CommonUtils.toString(getSettings().getProperties().get(PROP_MODEL), "");
        return CommonUtils.isEmpty(modelId) ? GPTModel.GPT_TURBO16 : GPTModel.getByName(modelId);
    }

    /** Dispatches the request to the legacy or chat completion endpoint and returns its choices. */
    private List<?> getCompletionChoices(GPTCompletionAdapter service, Object completionRequest) {
        if (completionRequest instanceof CompletionRequest singleRequest) {
            return service.createCompletion(singleRequest).getChoices();
        }
        return service.createChatCompletion((ChatCompletionRequest) completionRequest).getChoices();
    }

    /**
     * Flattens messages into one prompt. System messages are wrapped in {@code ###} lines with
     * every line prefixed by {@code #}; the prompt ends with {@code "SELECT "}.
     */
    @NotNull
    private static String buildSingleMessage(@NotNull List<DAICompletionMessage> messages) {
        StringJoiner buffer = new StringJoiner("\n");
        for (DAICompletionMessage message : messages) {
            if (message.getRole() == DAICompletionMessage.Role.SYSTEM) {
                buffer.add("###");
                buffer.add(message.getContent().lines().map(line -> "#" + line).collect(Collectors.joining("\n")));
                buffer.add("###");
            } else {
                buffer.add(message.getContent());
            }
        }
        buffer.add("SELECT ");
        return buffer.toString();
    }
}

