/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.rag;

import com.dbeaver.model.ai.RAGDatasourceKey;
import com.dbeaver.model.ai.RAGEmbeddedRecord;
import com.dbeaver.model.ai.RAGEmbeddingStorage;
import com.dbeaver.model.ai.RAGObjectKey;
import com.dbeaver.model.ai.RAGRelevantRecordFilter;
import com.dbeaver.model.ai.rag.AIEmbeddingGenerator;
import com.dbeaver.model.ai.rag.AIScopeSelector;
import com.dbeaver.model.ai.rag.AIVectorEmbedding;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.zip.CRC32;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.DBPNamedObject;
import org.jkiss.dbeaver.model.DBUtils;
import org.jkiss.dbeaver.model.ai.AIDatabaseScope;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.AISchemaGenerationOptions;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.ai.impl.AISchemaGeneratorImpl;
import org.jkiss.dbeaver.model.exec.DBCExecutionContext;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSEntity;
import org.jkiss.dbeaver.model.struct.DBSEntityAssociation;
import org.jkiss.dbeaver.model.struct.DBSObject;
import org.jkiss.dbeaver.model.struct.DBSObjectContainer;
import org.jkiss.dbeaver.model.virtual.DBVUtils;
import org.jkiss.dbeaver.registry.DataSourceDescriptor;

/**
 * {@link AIScopeSelector} that narrows an AI database context down to the entities that are
 * most similar to the user's messages, using vector embeddings kept in a {@link RAGEmbeddingStorage}.
 * <p>
 * For each request the selector:
 * <ol>
 *     <li>flattens the original scope (custom entity list or scope object) into entities;</li>
 *     <li>re-embeds every entity whose stored record is missing, was produced by a different
 *     embedding model, or whose generated schema text has a different CRC32 checksum;</li>
 *     <li>embeds each user message and asks the storage for the nearest stored records;</li>
 *     <li>returns a new {@link AIDatabaseScope#CUSTOM} context containing the matched entities.</li>
 * </ol>
 */
public class RAGScopeSelector implements AIScopeSelector {
    private static final Log log = Log.getLog(RAGScopeSelector.class);

    /**
     * Options for the schema text that gets embedded: object comments are included, while column
     * types, constraints and foreign keys are not; the token limit is effectively disabled.
     */
    private static final AISchemaGenerationOptions OPTIONS = AISchemaGenerationOptions.builder()
        .withSendColumnTypes(false)
        .withSendObjectComment(true)
        .withSendConstraints(false)
        .withSendForeignKeys(false)
        .withMaxDbSnapshotTokens(Integer.MAX_VALUE)
        .build();

    /** Maximum number of nearest records requested from the storage for each user message. */
    private static final int MAX_RELEVANT_RECORDS_PER_QUERY = 10;

    @NotNull
    private final RAGEmbeddingStorage embeddingStorage;
    @NotNull
    private final AISchemaGeneratorImpl schemaGenerator;

    /**
     * @param embeddingStorage storage used to look up, search and save entity embeddings
     * @param schemaGenerator  generator producing the per-entity schema text that is embedded
     */
    public RAGScopeSelector(@NotNull RAGEmbeddingStorage embeddingStorage, @NotNull AISchemaGeneratorImpl schemaGenerator) {
        this.embeddingStorage = embeddingStorage;
        this.schemaGenerator = schemaGenerator;
    }

    /**
     * Builds a {@link AIDatabaseScope#CUSTOM} context that keeps the data source and execution
     * context of {@code originalContext} but restricts its entities to the relevant ones.
     *
     * @return the narrowed context, or {@code null} when {@code originalContext} is {@code null}
     * @throws DBException if metadata reading, embedding or storage access fails
     */
    @Override
    @Nullable
    public AIDatabaseContext select(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @Nullable AIDatabaseContext originalContext, @NotNull List<AIMessage> messages) throws DBException {
        if (originalContext == null) {
            return null;
        }
        AIDatabaseContext.Builder builder = new AIDatabaseContext.Builder(originalContext.getDataSource())
            .setScope(AIDatabaseScope.CUSTOM);
        List<DBSObject> relevantObjects = doSelection(monitor, embeddingGenerator, originalContext, messages);
        return builder
            .setCustomEntities(relevantObjects)
            .setExecutionContext(originalContext.getExecutionContext())
            .build();
    }

    /**
     * Refreshes embeddings for every entity in scope and returns the entities relevant to the
     * user messages of the conversation.
     */
    @NotNull
    private List<DBSObject> doSelection(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull AIDatabaseContext originalContext, @NotNull List<AIMessage> messages) throws DBException {
        boolean loggingEnabled = embeddingGenerator.isLoggingEnabled();
        boolean customScope = originalContext.getScope() == AIDatabaseScope.CUSTOM;
        List<? extends DBSObject> from = customScope
            ? Objects.requireNonNull(originalContext.getCustomEntities())
            : List.of(originalContext.getScopeObject());
        List<String> userMessages = messages.stream()
            .filter(message -> message.getRole() == AIMessageType.USER)
            .map(AIMessage::getContent)
            .toList();

        // Entities are de-duplicated by object key while flattening: overlapping roots (e.g. a schema
        // and one of its tables) must not be embedded twice.
        List<DBSEntity> scopeObjects = flattenContainers(monitor, from);
        if (loggingEnabled) {
            log.debug("Updating embeddings for " + scopeObjects.size() + " objects");
        }
        updateEmbeddings(monitor, embeddingGenerator, originalContext.getExecutionContext(), scopeObjects);

        Map<RAGObjectKey, DBSEntity> objectsByKeys = new LinkedHashMap<>();
        for (DBSEntity entity : scopeObjects) {
            objectsByKeys.putIfAbsent(makeObjectKey(entity), entity);
        }

        List<DBSObject> dbsObjects;
        if (customScope) {
            if (loggingEnabled) {
                log.debug("Selecting relevant objects from custom scope");
            }
            dbsObjects = selectRelevantObjectsFromCustomScope(monitor, embeddingGenerator, originalContext, scopeObjects, userMessages, objectsByKeys);
        } else {
            if (loggingEnabled) {
                log.debug("Selecting relevant objects from " + String.valueOf(originalContext.getScopeObject()));
            }
            dbsObjects = selectRelevantObjectsFromContainer(monitor, embeddingGenerator, originalContext, userMessages, objectsByKeys);
        }

        if (loggingEnabled) {
            DBPDataSource dataSource = originalContext.getExecutionContext().getDataSource();
            String selectedNames = dbsObjects.stream()
                .map(object -> DBUtils.getFullQualifiedName(dataSource, new DBPNamedObject[]{object}))
                .collect(Collectors.joining(","));
            log.debug("Selected " + dbsObjects.size() + " objects (" + selectedNames + ") relevant objects");
        }
        return dbsObjects;
    }

    /**
     * Custom scope: returns only the in-scope entities whose embeddings matched a user message.
     */
    @NotNull
    private List<DBSObject> selectRelevantObjectsFromCustomScope(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull AIDatabaseContext originalContext, @NotNull List<DBSEntity> scopeObjects, @NotNull List<String> userMessages, @NotNull Map<RAGObjectKey, DBSEntity> objectsByKeys) throws DBException {
        return selectFromEntities(monitor, embeddingGenerator, originalContext.getExecutionContext(), scopeObjects, userMessages)
            .stream()
            .map(objectsByKeys::get)
            .filter(Objects::nonNull)
            .<DBSObject>map(entity -> entity)
            .toList();
    }

    /**
     * Container scope: returns the matched entities plus the associated entity of every association
     * reported for them by {@link DBVUtils#getAllReferences}, without duplicates.
     */
    @NotNull
    private List<DBSObject> selectRelevantObjectsFromContainer(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull AIDatabaseContext originalContext, @NotNull List<String> userMessages, @NotNull Map<RAGObjectKey, DBSEntity> objectsByKeys) throws DBException {
        return selectFromEntities(monitor, embeddingGenerator, originalContext.getExecutionContext(), objectsByKeys.values(), userMessages)
            .stream()
            .map(objectsByKeys::get)
            .filter(Objects::nonNull)
            .flatMap(entity -> {
                List<DBSEntity> related = new ArrayList<>();
                related.add(entity);
                DBVUtils.getAllReferences(monitor, entity).stream()
                    .map(DBSEntityAssociation::getAssociatedEntity)
                    .filter(Objects::nonNull)
                    .forEach(related::add);
                return related.stream();
            })
            .<DBSObject>map(entity -> entity)
            .distinct()
            .toList();
    }

    /**
     * Embeds every query and collects the keys of the nearest stored records, restricted to the
     * names of {@code from} within the context's project and data source.
     *
     * @return matched keys in the order the storage reported them (first occurrence wins)
     */
    @NotNull
    private Set<RAGObjectKey> selectFromEntities(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull DBCExecutionContext context, @NotNull Collection<DBSEntity> from, @NotNull List<String> queries) throws DBException {
        Set<RAGObjectKey> relevantKeys = new LinkedHashSet<>();
        if (from.isEmpty() || queries.isEmpty()) {
            // Nothing to search in or nothing to search for: avoid a pointless embedding call.
            return relevantKeys;
        }
        String projectId = context.getDataSource().getContainer().getProject().getId();
        String datasourceId = context.getDataSource().getContainer().getId();
        // Loop-invariant: the searched object names are the same for every query embedding.
        List<String> objectNames = from.stream().map(RAGScopeSelector::getObjectName).toList();

        for (AIVectorEmbedding queryEmbedding : embeddingGenerator.embedTexts(monitor, queries)) {
            RAGRelevantRecordFilter filter = new RAGRelevantRecordFilter(
                projectId, datasourceId, objectNames, queryEmbedding.vector(), MAX_RELEVANT_RECORDS_PER_QUERY);
            embeddingStorage.findRelevantRecords(filter).forEach(found -> relevantKeys.add(found.key()));
        }
        return relevantKeys;
    }

    /**
     * Re-embeds and saves every entity whose stored record is missing, was produced by another
     * embedding model, or whose schema-text checksum changed.
     *
     * @throws DBException if schema generation, embedding or storage access fails, or if the
     *                     generator returns fewer vectors than texts submitted
     */
    private void updateEmbeddings(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull DBCExecutionContext context, @NotNull Collection<DBSEntity> scopeObjects) throws DBException {
        String modelName = embeddingGenerator.getEmbeddingModelName();
        // Same key derivation as used when saving (makeObjectKey), so lookups hit the saved records.
        List<RAGObjectKey> objectKeys = scopeObjects.stream().map(RAGScopeSelector::makeObjectKey).toList();
        Map<RAGObjectKey, RAGEmbeddedRecord> storedRecords = embeddingStorage.findByKeys(objectKeys).stream()
            .collect(Collectors.toMap(RAGEmbeddedRecord::key, stored -> stored, (first, second) -> first));

        List<PendingEmbedding> pending = new ArrayList<>();
        for (DBSEntity entity : scopeObjects) {
            String schema = schemaGenerator.generateSchema(monitor, entity, context, OPTIONS, false);
            long checksum = computeChecksum(schema);
            RAGObjectKey objectKey = makeObjectKey(entity);
            RAGEmbeddedRecord stored = storedRecords.get(objectKey);
            boolean upToDate = stored != null
                && stored.checksum() == checksum
                && Objects.equals(modelName, stored.modelName());
            if (!upToDate) {
                pending.add(new PendingEmbedding(objectKey, schema, checksum));
            }
        }
        if (pending.isEmpty()) {
            return;
        }

        List<String> texts = pending.stream().map(PendingEmbedding::text).toList();
        List<AIVectorEmbedding> embeddings = embeddingGenerator.embedTexts(monitor, texts);
        if (embeddings.size() < pending.size()) {
            throw new DBException("Embedding generator returned " + embeddings.size()
                + " vectors for " + pending.size() + " texts");
        }
        List<RAGEmbeddedRecord> newRecords = new ArrayList<>(pending.size());
        for (int i = 0; i < pending.size(); i++) {
            PendingEmbedding item = pending.get(i);
            newRecords.add(new RAGEmbeddedRecord(item.key(), item.checksum(), modelName, embeddings.get(i).vector()));
        }
        embeddingStorage.save(newRecords);
    }

    /**
     * Expands every object into the entities found beneath it, dropping non-entities and any
     * entity whose object key was already seen (first occurrence wins).
     */
    @NotNull
    private static List<DBSEntity> flattenContainers(@NotNull DBRProgressMonitor monitor, @NotNull Collection<? extends DBSObject> objects) throws DBException {
        List<DBSEntity> entities = new ArrayList<>();
        Set<RAGObjectKey> seenKeys = new HashSet<>();
        for (DBSObject root : objects) {
            for (DBSObject object : flattenContainer(monitor, root)) {
                if (object instanceof DBSEntity entity && seenKeys.add(makeObjectKey(entity))) {
                    entities.add(entity);
                }
            }
        }
        return entities;
    }

    /**
     * Recursively collects the leaves under {@code object}: a non-container is returned as-is,
     * a container is replaced by the flattened leaves of its children.
     */
    @NotNull
    private static List<DBSObject> flattenContainer(@NotNull DBRProgressMonitor monitor, @NotNull DBSObject object) throws DBException {
        List<DBSObject> result = new ArrayList<>();
        if (!(object instanceof DBSObjectContainer container)) {
            result.add(object);
            return result;
        }
        // NOTE(review): 1 is the structure-caching scope passed to cacheStructure; presumably the
        // entity level of DBSObjectContainer's STRUCT_* flags — confirm.
        container.cacheStructure(monitor, 1);
        Collection<? extends DBSObject> children = container.getChildren(monitor);
        if (children != null) {
            for (DBSObject child : children) {
                result.addAll(flattenContainer(monitor, child));
            }
        }
        return result;
    }

    /**
     * Returns the dot-joined names from the top-most ancestor below the {@link DataSourceDescriptor}
     * down to {@code object} itself.
     */
    @NotNull
    private static String getObjectName(@NotNull DBSObject object) {
        List<String> names = new ArrayList<>();
        for (DBSObject current = object;
             current != null && !(current instanceof DataSourceDescriptor);
             current = current.getParentObject()) {
            names.add(current.getName());
        }
        Collections.reverse(names);
        return String.join(".", names);
    }

    /** Builds the storage key (project id, data source id, qualified object name) for an entity. */
    @NotNull
    private static RAGObjectKey makeObjectKey(@NotNull DBSEntity object) {
        return new RAGObjectKey(
            new RAGDatasourceKey(object.getDataSource().getContainer().getProject().getId(), object.getDataSource().getContainer().getId()),
            getObjectName(object));
    }

    /** CRC32 of the UTF-8 bytes of {@code text}; used only to detect schema-text changes. */
    private static long computeChecksum(@NotNull String text) {
        CRC32 crc = new CRC32();
        crc.update(text.getBytes(StandardCharsets.UTF_8));
        return crc.getValue();
    }

    /** An object whose embedding must be (re)computed, with the schema text and its checksum. */
    private record PendingEmbedding(@NotNull RAGObjectKey key, @NotNull String text, long checksum) {
    }
}

