/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.rag.strategy;

import com.dbeaver.model.ai.RAGEmbeddingStorage;
import com.dbeaver.model.ai.RAGObjectKey;
import com.dbeaver.model.ai.RAGRelevantRecordFilter;
import com.dbeaver.model.ai.rag.AIVectorEmbedding;
import com.dbeaver.model.ai.rag.RAGUtils;
import com.dbeaver.model.ai.rag.strategy.SelectionStrategy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSEntity;
import org.jkiss.dbeaver.model.struct.DBSObject;

/**
 * {@link SelectionStrategy} that limits RAG lookup to the entities listed in the custom scope of an
 * {@link AIDatabaseContext}.
 * <p>
 * The context's custom entities are expanded with {@link RAGUtils#flattenContainers} into a candidate
 * set. For every query embedding, the embedding storage is asked for the most relevant records whose
 * keys belong to that candidate set. The matching entities are returned once each, in first-hit order.
 */
public class CustomScopeSelectionStrategy
implements SelectionStrategy {
    /** Number of relevant records requested from the storage per query embedding by default. */
    private static final int DEFAULT_MAX_RECORDS_PER_EMBEDDING = 5;

    @NotNull
    private final RAGEmbeddingStorage embeddingStorage;
    private final int maxRecordsPerEmbedding;

    /**
     * Creates a strategy that requests {@value #DEFAULT_MAX_RECORDS_PER_EMBEDDING} records per query
     * embedding.
     *
     * @param embeddingStorage storage queried for relevant records
     */
    public CustomScopeSelectionStrategy(@NotNull RAGEmbeddingStorage embeddingStorage) {
        this(embeddingStorage, DEFAULT_MAX_RECORDS_PER_EMBEDDING);
    }

    /**
     * Creates a strategy with a custom per-embedding record limit.
     *
     * @param embeddingStorage       storage queried for relevant records
     * @param maxRecordsPerEmbedding how many records to request from the storage for each query embedding
     * @throws IllegalArgumentException if {@code maxRecordsPerEmbedding} is not positive
     */
    public CustomScopeSelectionStrategy(@NotNull RAGEmbeddingStorage embeddingStorage, int maxRecordsPerEmbedding) {
        if (maxRecordsPerEmbedding <= 0) {
            throw new IllegalArgumentException("maxRecordsPerEmbedding must be positive: " + maxRecordsPerEmbedding);
        }
        this.embeddingStorage = embeddingStorage;
        this.maxRecordsPerEmbedding = maxRecordsPerEmbedding;
    }

    /**
     * Selects the custom-scope entities most relevant to the given query embeddings.
     *
     * @param monitor         progress monitor passed to container flattening
     * @param queryEmbeddings query embeddings; each one triggers a separate storage lookup
     * @param originalContext context supplying the custom entities and the execution context's data source
     * @return matching entities without duplicates, in the order they were first found. The list is
     *     empty when the context has no custom entities or flattening produces no candidate keys.
     * @throws DBException if flattening the containers or querying the embedding storage fails
     */
    @NotNull
    @Override
    public List<DBSObject> selectObjectsForScope(@NotNull DBRProgressMonitor monitor, @NotNull List<AIVectorEmbedding> queryEmbeddings, @NotNull AIDatabaseContext originalContext) throws DBException {
        var customEntities = originalContext.getCustomEntities();
        if (customEntities == null || customEntities.isEmpty()) {
            return Collections.emptyList();
        }
        DBPDataSource dataSource = originalContext.getExecutionContext().getDataSource();
        // NOTE(review): flattenContainers may yield the same entity (or key) more than once, presumably
        // when a container and one of its children are both selected. Keep the first mapping instead of
        // letting toMap throw IllegalStateException on the duplicate key.
        Map<RAGObjectKey, DBSEntity> entityKeyMap = RAGUtils.flattenContainers(monitor, customEntities).stream()
            .collect(Collectors.toMap(
                entity -> RAGObjectKey.fromEntity(dataSource, entity),
                entity -> entity,
                (first, second) -> first,
                LinkedHashMap::new));
        if (entityKeyMap.isEmpty()) {
            // No candidate keys: no storage record could map back to an entity, so skip the lookups.
            return Collections.emptyList();
        }
        // Several query embeddings can hit the same entity; report it once and keep first-hit order.
        Set<DBSObject> selected = new LinkedHashSet<>();
        for (AIVectorEmbedding embedding : queryEmbeddings) {
            RAGRelevantRecordFilter filter = new RAGRelevantRecordFilter(entityKeyMap.keySet(), embedding.vector(), maxRecordsPerEmbedding);
            embeddingStorage.findRelevantRecords(filter).forEach(relevantRecord -> {
                DBSEntity entity = entityKeyMap.get(relevantRecord.key());
                // Skip records whose key is not one of the requested ones instead of adding null.
                if (entity != null) {
                    selected.add(entity);
                }
            });
        }
        return new ArrayList<>(selected);
    }
}

