/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.rag;

import com.dbeaver.model.ai.RAGEmbeddedRecord;
import com.dbeaver.model.ai.RAGEmbeddingStorage;
import com.dbeaver.model.ai.RAGObjectKey;
import com.dbeaver.model.ai.rag.AIEmbeddingGenerator;
import com.dbeaver.model.ai.rag.AIVectorEmbedding;
import com.dbeaver.model.ai.rag.RAGObjectDescriber;
import com.dbeaver.model.ai.rag.RAGObjectDescription;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.ai.impl.DummyTokenCounter;
import org.jkiss.dbeaver.model.ai.impl.TokenCounter;
import org.jkiss.dbeaver.model.exec.DBCExecutionContext;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSEntity;
import org.jkiss.dbeaver.model.struct.DBSObject;

/**
 * Keeps the embedding storage in sync with a set of database entities.
 *
 * <p>Entities are split into batches of {@link #BATCH_SIZE}; each batch is described,
 * compared against the records already present in the {@link RAGEmbeddingStorage},
 * and only the entities whose stored record is missing, was produced by another
 * embedding model, or has a different checksum are embedded again and saved.
 */
public class RAGIndexBuilder {
    private static final Log log = Log.getLog(RAGIndexBuilder.class);
    /** Maximum number of entities processed by a single asynchronous task. */
    private static final int BATCH_SIZE = 200;
    @NotNull
    private final RAGEmbeddingStorage embeddingStorage;
    @NotNull
    private final RAGObjectDescriber objectDescriber;
    private final TokenCounter tokenCounter = new DummyTokenCounter();
    private final ExecutorService executor;

    /**
     * @param embeddingStorage storage that is queried for existing records and receives new ones
     * @param objectDescriber  produces the description (key, checksum, schema text) of an entity
     * @param executor         executor on which the per-batch indexing tasks are run
     */
    public RAGIndexBuilder(@NotNull RAGEmbeddingStorage embeddingStorage, @NotNull RAGObjectDescriber objectDescriber, @NotNull ExecutorService executor) {
        this.embeddingStorage = embeddingStorage;
        this.objectDescriber = objectDescriber;
        this.executor = executor;
    }

    /**
     * Schedules an index update for the given entities, one task per batch of
     * {@link #BATCH_SIZE} entities, on this builder's executor.
     *
     * @param monitor            progress monitor passed through to describing and embedding
     * @param embeddingGenerator generator used to embed the entity descriptions
     * @param dbsObjects         entities to (re)index; an empty list yields an already completed future
     * @param ctx                execution context used to describe the entities
     * @return a future that completes when every batch has finished; a {@link DBException}
     *         raised by a batch is rethrown inside the task wrapped in a {@link RuntimeException}
     */
    public CompletableFuture<Void> updateIndexAsyncFor(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull List<DBSEntity> dbsObjects, @NotNull DBCExecutionContext ctx) {
        if (dbsObjects.isEmpty()) {
            return CompletableFuture.completedFuture(null);
        }
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (List<DBSEntity> batch : this.batchObjects(dbsObjects, BATCH_SIZE)) {
            futures.add(CompletableFuture.runAsync(() -> {
                try {
                    this.updateIndexFor(monitor, embeddingGenerator, batch, ctx);
                }
                catch (DBException e) {
                    // Runnable cannot throw checked exceptions; keep the cause for the caller.
                    throw new RuntimeException(e);
                }
            }, this.executor));
        }
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
    }

    /**
     * Synchronously refreshes the stored embeddings for one batch of entities.
     *
     * @throws DBException if describing, embedding or storage access fails, or if the
     *                     generator returns fewer embeddings than there are objects to index
     */
    private void updateIndexFor(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull Collection<DBSEntity> dbsObjects, @NotNull DBCExecutionContext ctx) throws DBException {
        List<RAGObjectDescription> objectsToIndex = this.findObjectsToIndex(monitor, embeddingGenerator.getEmbeddingModelName(), dbsObjects, ctx);
        if (objectsToIndex.isEmpty()) {
            return;
        }
        if (embeddingGenerator.isLoggingEnabled()) {
            log.debug((Object)("Found " + objectsToIndex.size() + " objects to index"));
        }
        List<AIVectorEmbedding> vectorEmbeddings = this.computeEmbeddings(monitor, embeddingGenerator, objectsToIndex);
        // Descriptions and embeddings are paired by position below; a short result would
        // otherwise surface as an unexplained IndexOutOfBoundsException.
        if (vectorEmbeddings.size() < objectsToIndex.size()) {
            throw new DBException("Embedding generator returned " + vectorEmbeddings.size() + " embeddings for " + objectsToIndex.size() + " objects");
        }
        List<RAGEmbeddedRecord> recordsToSave = RAGIndexBuilder.buildEmbeddedRecords(embeddingGenerator, objectsToIndex, vectorEmbeddings);
        this.embeddingStorage.save(recordsToSave);
    }

    /**
     * Pairs each description with the embedding at the same index and wraps the pair,
     * together with the generator's model name, into a storage record.
     *
     * @param objectsToIndex   descriptions that were embedded
     * @param vectorEmbeddings embeddings in the same order; must hold at least as many
     *                         elements as {@code objectsToIndex}
     */
    @NotNull
    private static List<RAGEmbeddedRecord> buildEmbeddedRecords(@NotNull AIEmbeddingGenerator embeddingGenerator, List<RAGObjectDescription> objectsToIndex, List<AIVectorEmbedding> vectorEmbeddings) {
        String modelName = embeddingGenerator.getEmbeddingModelName();
        List<RAGEmbeddedRecord> recordsToSave = new ArrayList<>(objectsToIndex.size());
        for (int i = 0; i < objectsToIndex.size(); i++) {
            RAGObjectDescription description = objectsToIndex.get(i);
            AIVectorEmbedding vectorEmbedding = vectorEmbeddings.get(i);
            recordsToSave.add(new RAGEmbeddedRecord(description.key(), description.checksum(), modelName, vectorEmbedding.vector()));
        }
        return recordsToSave;
    }

    /**
     * Describes every entity and returns the descriptions that need a new embedding:
     * those with no stored record, or whose stored record has a different model name
     * or checksum.
     *
     * @param embeddingModelName model name a stored record must carry to be considered current
     * @throws DBException propagated from the describer or the storage
     */
    private List<RAGObjectDescription> findObjectsToIndex(@NotNull DBRProgressMonitor monitor, @NotNull String embeddingModelName, @NotNull Collection<DBSEntity> dbsObjects, @NotNull DBCExecutionContext ctx) throws DBException {
        // The merge function keeps the first entity when the same key occurs more than once;
        // without it Collectors.toMap throws IllegalStateException and the whole batch fails.
        Map<RAGObjectKey, DBSEntity> entitiesByKey = dbsObjects.stream().collect(Collectors.toMap(
            it -> RAGObjectKey.fromEntity((DBPDataSource)ctx.getDataSource(), (DBSObject)it),
            it -> it,
            (DBSEntity first, DBSEntity duplicate) -> first));
        // Should the storage return several records for one key, prefer the one produced by
        // the current model so that an up-to-date entry is not embedded again.
        Map<RAGObjectKey, RAGEmbeddedRecord> storedVectors = this.embeddingStorage.findByKeys(entitiesByKey.keySet()).stream().collect(Collectors.toMap(
            RAGEmbeddedRecord::key,
            it -> it,
            (RAGEmbeddedRecord first, RAGEmbeddedRecord second) -> embeddingModelName.equals(first.modelName()) ? first : second));
        List<RAGObjectDescription> objectsToUpdate = new ArrayList<>();
        for (DBSEntity entity : entitiesByKey.values()) {
            RAGObjectDescription description = this.objectDescriber.describe(monitor, entity, ctx);
            // NOTE(review): `==` is only a value comparison if checksum() returns a primitive;
            // confirm against RAGEmbeddedRecord / RAGObjectDescription.
            boolean upToDate = Optional.ofNullable(storedVectors.get(description.key()))
                .filter(r -> r.checksum() == description.checksum())
                .filter(r -> embeddingModelName.equals(r.modelName()))
                .isPresent();
            if (!upToDate) {
                objectsToUpdate.add(description);
            }
        }
        return objectsToUpdate;
    }

    /**
     * Embeds the schema text of every description, in order. Texts are truncated to the
     * generator's context limit and grouped so that the summed token count of one
     * {@code embedTexts} call stays within that limit (a group always holds at least one text).
     *
     * @return the embeddings returned by the generator, concatenated in call order
     * @throws DBException propagated from the generator
     */
    private List<AIVectorEmbedding> computeEmbeddings(@NotNull DBRProgressMonitor monitor, @NotNull AIEmbeddingGenerator embeddingGenerator, @NotNull List<RAGObjectDescription> objectsToIndex) throws DBException {
        int contextLimit = embeddingGenerator.getEmbeddingContextLimit();
        List<AIVectorEmbedding> result = new ArrayList<>();
        List<String> textsBatch = new ArrayList<>();
        int currentContextTokens = 0;
        for (RAGObjectDescription toIndex : objectsToIndex) {
            String text = this.tokenCounter.truncateToTokenLimit(toIndex.schema(), contextLimit);
            int objectTokens = this.tokenCounter.count(text);
            // Flush the pending group before adding a text that would push it over the limit.
            if (!textsBatch.isEmpty() && currentContextTokens + objectTokens > contextLimit) {
                result.addAll(embeddingGenerator.embedTexts(monitor, textsBatch));
                textsBatch.clear();
                currentContextTokens = 0;
            }
            textsBatch.add(text);
            currentContextTokens += objectTokens;
        }
        if (!textsBatch.isEmpty()) {
            result.addAll(embeddingGenerator.embedTexts(monitor, textsBatch));
        }
        return result;
    }

    /**
     * Splits the list into consecutive, independent copies of at most {@code batchSize} elements.
     *
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    private List<List<DBSEntity>> batchObjects(@NotNull List<DBSEntity> dbsObjects, int batchSize) {
        if (batchSize <= 0) {
            // A non-positive step would never advance the loop below.
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        int total = dbsObjects.size();
        List<List<DBSEntity>> batches = new ArrayList<>();
        for (int from = 0; from < total; from += batchSize) {
            int to = Math.min(from + batchSize, total);
            // Copy the view so each task owns its batch independently of the caller's list.
            batches.add(new ArrayList<>(dbsObjects.subList(from, to)));
        }
        return batches;
    }
}

