/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.engine.aws;

import com.dbeaver.model.ai.engine.aws.AwsBedrockProperties;
import com.dbeaver.model.ai.engine.aws.AwsModels;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.Strictness;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.engine.AIEngineProperties;
import org.jkiss.dbeaver.model.ai.engine.AIEngineRequest;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponse;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseChunk;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseConsumer;
import org.jkiss.dbeaver.model.ai.engine.AIFunctionCall;
import org.jkiss.dbeaver.model.ai.engine.AIModel;
import org.jkiss.dbeaver.model.ai.engine.BaseCompletionEngine;
import org.jkiss.dbeaver.model.ai.registry.AIFunctionDescriptor;
import org.jkiss.dbeaver.model.ai.utils.DisposableLazyValue;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.services.bedrock.BedrockAsyncClient;
import software.amazon.awssdk.services.bedrock.BedrockAsyncClientBuilder;
import software.amazon.awssdk.services.bedrock.model.InferenceType;
import software.amazon.awssdk.services.bedrock.model.ListFoundationModelsRequest;
import software.amazon.awssdk.services.bedrock.model.ListFoundationModelsResponse;
import software.amazon.awssdk.services.bedrock.model.ListInferenceProfilesRequest;
import software.amazon.awssdk.services.bedrock.model.ListInferenceProfilesResponse;
import software.amazon.awssdk.services.bedrock.model.ModelModality;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;

/**
 * Completion engine that talks to Amazon Bedrock through the AWS SDK v2 async clients.
 * <p>
 * Non-streaming completions use the Converse API, streaming completions use ConverseStream.
 * Model listing returns either inference profiles or on-demand text foundation models,
 * depending on {@link AwsBedrockProperties#isShowInferences()}.
 * Both SDK clients are created lazily and released in {@link #close()}.
 */
public class AwsBedrockEngine
extends BaseCompletionEngine<AwsBedrockProperties> {
    /** Stop reason reported by Bedrock when the model asks for a tool (function) invocation. */
    private static final String STOP_REASON_TOOL_USE = "tool_use";
    private static final String MODELS_FETCH_ERROR = "Error fetching models from AWS Bedrock";
    private static final Gson GSON = new GsonBuilder().setStrictness(Strictness.LENIENT).serializeNulls().create();

    /** Lazily created control-plane client (model / inference profile listing). */
    private final DisposableLazyValue<BedrockAsyncClient, DBException> metadataClient = new DisposableLazyValue<BedrockAsyncClient, DBException>(){

        protected void onDispose(@NotNull BedrockAsyncClient disposedValue) {
            disposedValue.close();
        }

        @NotNull
        protected BedrockAsyncClient initialize() {
            return AwsBedrockEngine.this.createMetadataBedrockAsyncClient();
        }
    };

    /** Lazily created runtime client (Converse / ConverseStream). */
    private final DisposableLazyValue<BedrockRuntimeAsyncClient, DBException> runtimeClient = new DisposableLazyValue<BedrockRuntimeAsyncClient, DBException>(){

        protected void onDispose(@NotNull BedrockRuntimeAsyncClient disposedValue) {
            disposedValue.close();
        }

        @NotNull
        protected BedrockRuntimeAsyncClient initialize() {
            return AwsBedrockEngine.this.createRuntimeBedrockClient();
        }
    };

    public AwsBedrockEngine(@NotNull AwsBedrockProperties properties) {
        super(properties);
    }

    /**
     * Lists the models available to the configured credentials and region.
     *
     * @return inference profiles when "show inferences" is enabled, otherwise on-demand text foundation models
     * @throws DBException if the AWS call fails
     */
    @NotNull
    public List<AIModel> getModels(@NotNull DBRProgressMonitor monitor) throws DBException {
        return this.settings().isShowInferences() ? this.listInferenceProfiles() : this.listFoundationModels();
    }

    /**
     * Sends a single (non-streaming) Converse request.
     *
     * @return a function call if the model stopped with {@code tool_use}, otherwise the text blocks of the reply
     * @throws DBException if the request fails or a tool-use reply contains no usable tool block
     */
    @NotNull
    public AIEngineResponse requestCompletion(@NotNull DBRProgressMonitor monitor, @NotNull AIEngineRequest request) throws DBException {
        ConverseRequest.Builder builder = ConverseRequest.builder()
            .inferenceConfig(this.buildInference())
            .modelId(this.settings().getModel())
            .messages(this.mapMessages(request));
        SystemContentBlock systemPrompt = this.findSystemPrompt(request);
        if (systemPrompt != null) {
            // Never pass a null element: the request is only given a system block when one exists.
            builder.system(systemPrompt);
        }
        if (!request.getFunctions().isEmpty()) {
            builder.toolConfig(this.mapTools(request));
        }
        BedrockRuntimeAsyncClient client = this.runtimeClient.getInstance();
        ConverseResponse response = awaitResult(client.converse(builder.build()), "Error during Bedrock converse request");

        Message message = response.output().message();
        // Null-safe comparison: the stop reason may be absent.
        if (STOP_REASON_TOOL_USE.equals(response.stopReasonAsString())) {
            return this.extractFunctionCall(message);
        }
        List<String> completions = new ArrayList<>();
        for (ContentBlock contentBlock : message.content()) {
            if (contentBlock.text() != null) {
                completions.add(contentBlock.text());
            }
        }
        return new AIEngineResponse(AIMessageType.ASSISTANT, completions);
    }

    /**
     * Sends a ConverseStream request and forwards chunks to {@code listener}.
     * <p>
     * Text deltas are forwarded as they arrive. Tool-use input is delivered by Bedrock as JSON
     * fragments spread over several deltas, so the fragments are buffered per content block and
     * a single function-call chunk is emitted when that block stops.
     * Stream errors and completion are also reported through the listener's {@code error} and
     * {@code close} callbacks.
     *
     * @throws DBException if the stream request fails or the thread is interrupted
     */
    public void requestCompletionStream(@NotNull DBRProgressMonitor monitor, @NotNull AIEngineRequest request, @NotNull AIEngineResponseConsumer listener) throws DBException {
        ConverseStreamRequest.Builder builder = ConverseStreamRequest.builder()
            .inferenceConfig(this.buildInference())
            .modelId(this.settings().getModel())
            .messages(this.mapMessages(request));
        SystemContentBlock systemPrompt = this.findSystemPrompt(request);
        if (systemPrompt != null) {
            builder.system(systemPrompt);
        }
        if (!request.getFunctions().isEmpty()) {
            builder.toolConfig(this.mapTools(request));
        }
        ConverseStreamRequest converseRequest = builder.build();

        // Name of the tool whose content block is currently open (null when not in a tool block).
        AtomicReference<String> toolName = new AtomicReference<>();
        // Accumulated partial JSON input for the currently open tool block.
        StringBuilder toolInput = new StringBuilder();

        ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
            .onContentBlockStart(event -> {
                ContentBlockStart start = event.start();
                if (start != null && start.toolUse() != null) {
                    toolName.set(start.toolUse().name());
                    toolInput.setLength(0);
                }
            })
            .onContentBlockDelta(event -> {
                ContentBlockDelta delta = event.delta();
                if (delta == null) {
                    return;
                }
                if (delta.toolUse() != null) {
                    if (delta.toolUse().input() != null) {
                        toolInput.append(delta.toolUse().input());
                    }
                } else if (delta.text() != null) {
                    listener.nextChunk(new AIEngineResponseChunk(List.of(delta.text())));
                }
            })
            .onContentBlockStop(event -> {
                String name = toolName.getAndSet(null);
                if (name != null) {
                    Map<String, Object> arguments = parseToolInput(toolInput.toString());
                    toolInput.setLength(0);
                    listener.nextChunk(new AIEngineResponseChunk(new AIFunctionCall(name, arguments)));
                }
            })
            .build();

        ConverseStreamResponseHandler handler = ConverseStreamResponseHandler.builder()
            .subscriber(visitor)
            .onError(error -> listener.error(error))
            .onComplete(() -> listener.close())
            .build();

        BedrockRuntimeAsyncClient client = this.runtimeClient.getInstance();
        awaitResult(client.converseStream(converseRequest, handler), "Error during Bedrock converse stream request");
    }

    public int getContextWindowSize(@NotNull DBRProgressMonitor monitor) {
        return this.settings().getContextWindowSize();
    }

    /**
     * Disposes both SDK clients. The runtime client is disposed even if disposing the
     * metadata client fails.
     */
    public void close() throws DBException {
        try {
            this.metadataClient.dispose();
        }
        finally {
            this.runtimeClient.dispose();
        }
    }

    /** Typed accessor for the engine properties. */
    @NotNull
    private AwsBedrockProperties settings() {
        return (AwsBedrockProperties)this.properties;
    }

    /**
     * Waits for an SDK future and translates failures into {@link DBException}.
     * <p>
     * {@link CompletableFuture#get()} wraps SDK errors in {@link ExecutionException}, so the cause
     * is unwrapped to produce model-specific messages.
     *
     * @param future       the pending SDK call
     * @param errorMessage message used for failures that are not Bedrock runtime errors
     * @return the future's value (may be {@code null} for {@code Void} futures)
     * @throws DBException on failure or interruption
     */
    @Nullable
    private static <T> T awaitResult(@NotNull CompletableFuture<T> future, @NotNull String errorMessage) throws DBException {
        try {
            return future.get();
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            future.cancel(true);
            throw new DBException(errorMessage, (Throwable)e);
        }
        catch (ExecutionException e) {
            Throwable cause = e.getCause() == null ? e : e.getCause();
            throw toDBException(cause, errorMessage);
        }
        catch (RuntimeException e) {
            // e.g. CancellationException when the request was cancelled elsewhere
            throw new DBException(errorMessage, (Throwable)e);
        }
    }

    /** Maps an SDK failure cause to a {@link DBException} with a descriptive message. */
    @NotNull
    private static DBException toDBException(@NotNull Throwable cause, @NotNull String errorMessage) {
        // ModelNotReadyException is a BedrockRuntimeException, so it must be checked first.
        if (cause instanceof ModelNotReadyException) {
            return new DBException("Model is not ready: " + cause.getMessage(), cause);
        }
        if (cause instanceof BedrockRuntimeException) {
            return new DBException("Failed to converse with Bedrock model: " + cause.getMessage(), cause);
        }
        if (cause instanceof DBException dbException) {
            return dbException;
        }
        return new DBException(errorMessage, cause);
    }

    /**
     * Extracts the first tool-use block of a {@code tool_use} reply as a function call.
     * Text blocks are skipped; any other block type is rejected.
     */
    @NotNull
    private AIEngineResponse extractFunctionCall(@NotNull Message message) throws DBException {
        for (ContentBlock contentBlock : message.content()) {
            if (contentBlock.type() == ContentBlock.Type.TOOL_USE && contentBlock.toolUse() != null) {
                return new AIEngineResponse(new AIFunctionCall(contentBlock.toolUse().name(), this.toArguments(contentBlock.toolUse().input())));
            }
            if (contentBlock.text() == null) {
                throw new DBException("Unexpected content block type in tool use response: " + String.valueOf(contentBlock.type()));
            }
        }
        throw new DBException("Expected tool use content block, but none found.");
    }

    /** Converts a tool-use input document into an argument map; an absent or null input yields an empty map. */
    @NotNull
    private Map<String, Object> toArguments(@Nullable Document input) throws DBException {
        if (input == null || input.isNull()) {
            return Map.of();
        }
        if (!input.isMap()) {
            throw new DBException("Unexpected tool input format: " + input);
        }
        return this.convertMap(input.asMap());
    }

    /** Parses buffered streaming tool input; blank input yields an empty map. */
    @NotNull
    @SuppressWarnings("unchecked")
    private static Map<String, Object> parseToolInput(@NotNull String json) {
        if (json.isBlank()) {
            return Map.of();
        }
        Map<String, Object> parsed = GSON.fromJson(json, Map.class);
        return parsed == null ? Map.of() : parsed;
    }

    /** Unwraps SDK documents into plain Java values; a loop is used because unwrapped values may be null. */
    @NotNull
    private Map<String, Object> convertMap(@NotNull Map<String, Document> map) {
        HashMap<String, Object> result = new HashMap<String, Object>();
        for (Map.Entry<String, Document> entry : map.entrySet()) {
            result.put(entry.getKey(), entry.getValue().unwrap());
        }
        return result;
    }

    /** Lists all inference profiles, following {@code nextToken} across pages. */
    @NotNull
    private List<AIModel> listInferenceProfiles() throws DBException {
        BedrockAsyncClient client = this.metadataClient.getInstance();
        List<AIModel> models = new ArrayList<>();
        String nextToken = null;
        do {
            ListInferenceProfilesRequest request = ListInferenceProfilesRequest.builder().nextToken(nextToken).build();
            ListInferenceProfilesResponse page = awaitResult(client.listInferenceProfiles(request), MODELS_FETCH_ERROR);
            try {
                models.addAll(page.inferenceProfileSummaries().stream()
                    .map(inference -> new AIModel(inference.inferenceProfileId(), AwsModels.getContextWindowSize(), AwsModels.mapFeaturesForInferences(inference, client)))
                    .toList());
            }
            catch (RuntimeException e) {
                throw new DBException(MODELS_FETCH_ERROR, (Throwable)e);
            }
            nextToken = page.nextToken();
        } while (nextToken != null && !nextToken.isEmpty());
        return List.copyOf(models);
    }

    /** Lists foundation models that produce text and support on-demand inference. */
    @NotNull
    private List<AIModel> listFoundationModels() throws DBException {
        ListFoundationModelsRequest request = ListFoundationModelsRequest.builder()
            .byOutputModality(ModelModality.TEXT)
            .byInferenceType(InferenceType.ON_DEMAND)
            .build();
        ListFoundationModelsResponse response = awaitResult(this.metadataClient.getInstance().listFoundationModels(request), MODELS_FETCH_ERROR);
        try {
            return response.modelSummaries().stream()
                .map(model -> new AIModel(model.modelId(), AwsModels.getContextWindowSize(), AwsModels.mapFeaturesForModels(model)))
                .toList();
        }
        catch (RuntimeException e) {
            throw new DBException(MODELS_FETCH_ERROR, (Throwable)e);
        }
    }

    /** Builds the inference configuration from the engine properties. */
    @NotNull
    private InferenceConfiguration buildInference() {
        AwsBedrockProperties settings = this.settings();
        // NOTE(review): maxTokens limits the *output* length, but it is set from the context window
        // size here; confirm this matches the intended per-model output limit.
        return InferenceConfiguration.builder()
            .temperature(Float.valueOf((float)settings.getTemperature()))
            .maxTokens(settings.getContextWindowSize())
            .build();
    }

    /** Returns the first SYSTEM message of the request as a system block, or {@code null} if there is none. */
    @Nullable
    private SystemContentBlock findSystemPrompt(@NotNull AIEngineRequest request) {
        Optional<AIMessage> message = request.getMessages().stream().filter(m -> AIMessageType.SYSTEM.equals(m.getRole())).findFirst();
        return message.map(aiMessage -> SystemContentBlock.builder().text(aiMessage.getContent()).build()).orElse(null);
    }

    /**
     * Converts the request's function descriptors into a Bedrock tool configuration.
     * Each tool gets a JSON schema of type "object" in which every parameter is marked required.
     * Tools keep the order of {@code request.getFunctions()}.
     */
    @NotNull
    private ToolConfiguration mapTools(@NotNull AIEngineRequest request) {
        List<Tool> toolList = request.getFunctions().stream()
            .map(function -> {
                HashMap<String, Document> propertiesMap = new HashMap<String, Document>();
                ArrayList<Document> requiredParameters = new ArrayList<Document>();
                for (AIFunctionDescriptor.Parameter parameter : function.getParameters()) {
                    requiredParameters.add(Document.fromString((String)parameter.getName()));
                    HashMap<String, Document> parameterMap = new HashMap<String, Document>();
                    parameterMap.put("type", Document.fromString((String)parameter.getType()));
                    if (parameter.getDescription() != null) {
                        parameterMap.put("description", Document.fromString((String)parameter.getDescription()));
                    }
                    propertiesMap.put(parameter.getName(), Document.fromMap(parameterMap));
                }
                HashMap<String, Document> rootMap = new HashMap<String, Document>();
                rootMap.put("type", Document.fromString("object"));
                rootMap.put("properties", Document.fromMap(propertiesMap));
                rootMap.put("required", Document.fromList(requiredParameters));
                return ToolSpecification.builder()
                    .name(function.getId())
                    .description(function.getDescription())
                    .inputSchema(ToolInputSchema.builder().json(Document.fromMap(rootMap)).build())
                    .build();
            })
            .map(spec -> Tool.builder().toolSpec(spec).build())
            .toList();
        return ToolConfiguration.builder().tools(toolList).build();
    }

    /** Maps USER and ASSISTANT messages to Bedrock messages; other roles (e.g. SYSTEM) are dropped. */
    @NotNull
    private List<Message> mapMessages(@NotNull AIEngineRequest request) {
        return request.getMessages().stream().map(m -> switch (m.getRole()) {
            case USER -> Message.builder().content(ContentBlock.fromText((String)m.getContent())).role(ConversationRole.USER).build();
            case ASSISTANT -> Message.builder().content(ContentBlock.fromText((String)m.getContent())).role(ConversationRole.ASSISTANT).build();
            default -> (Message)null;
        }).filter(Objects::nonNull).toList();
    }

    @NotNull
    private BedrockAsyncClient createMetadataBedrockAsyncClient() {
        BedrockAsyncClientBuilder builder = BedrockAsyncClient.builder();
        this.processAuthenticationAndRegion(builder);
        return builder.build();
    }

    @NotNull
    private BedrockRuntimeAsyncClient createRuntimeBedrockClient() {
        BedrockRuntimeAsyncClientBuilder builder = BedrockRuntimeAsyncClient.builder();
        this.processAuthenticationAndRegion(builder);
        return builder.build();
    }

    /** Static credentials built from the configured access key and secret key. */
    @NotNull
    private AwsCredentialsProvider createCredentialsProvider() {
        AwsBedrockProperties settings = this.settings();
        return StaticCredentialsProvider.create(AwsBasicCredentials.create(settings.getAccessKey(), settings.getSecretAccessKey()));
    }

    private void processAuthenticationAndRegion(@NotNull BedrockRuntimeAsyncClientBuilder builder) {
        builder.credentialsProvider(this.createCredentialsProvider());
        builder.region(this.settings().getAwsRegion());
    }

    private void processAuthenticationAndRegion(@NotNull BedrockAsyncClientBuilder builder) {
        builder.credentialsProvider(this.createCredentialsProvider());
        builder.region(this.settings().getAwsRegion());
    }
}

