/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.rag.strategy;

import com.dbeaver.model.ai.RAGEmbeddedRecord;
import com.dbeaver.model.ai.RAGEmbeddingStorage;
import com.dbeaver.model.ai.RAGObjectKey;
import com.dbeaver.model.ai.RAGPrefixRelevantRecordFilter;
import com.dbeaver.model.ai.rag.AIVectorEmbedding;
import com.dbeaver.model.ai.rag.RAGUtils;
import com.dbeaver.model.ai.rag.strategy.SelectionStrategy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.struct.DBSEntity;
import org.jkiss.dbeaver.model.struct.DBSEntityAssociation;
import org.jkiss.dbeaver.model.struct.DBSObject;
import org.jkiss.dbeaver.model.virtual.DBVUtils;

/**
 * {@link SelectionStrategy} that selects entities of the current AI scope whose stored
 * embeddings match the query embeddings, together with the entities they reference.
 *
 * <p>For every query embedding, up to {@link #RETRIEVE_COUNT} candidate records are requested
 * from the embedding storage, filtered by the scope object's key prefix. The first
 * {@link #MAX_SELECTIONS_PER_EMBEDDING} records that resolve to an entity of the scope are
 * selected. Each selected entity is expanded with its references. The combined result is
 * de-duplicated by {@link RAGObjectKey}.
 */
public class GeneralSelectionStrategy implements SelectionStrategy {
    /** Maximum number of candidate records requested from the storage per query embedding. */
    private static final int RETRIEVE_COUNT = 100;
    /** Maximum number of scope entities selected per query embedding (references not counted). */
    private static final int MAX_SELECTIONS_PER_EMBEDDING = 5;
    /**
     * Structure flags passed to {@code cacheStructure} for the scope object.
     * NOTE(review): presumably {@code DBSObjectContainer.STRUCT_ASSOCIATIONS}, so that the
     * reference lookups in {@link #extendWithReferences} hit the cache. Confirm.
     */
    private static final int SCOPE_CACHE_FLAGS = 4;

    private final RAGEmbeddingStorage embeddingStorage;

    /**
     * @param embeddingStorage storage queried for records relevant to each query embedding
     */
    public GeneralSelectionStrategy(@NotNull RAGEmbeddingStorage embeddingStorage) {
        this.embeddingStorage = embeddingStorage;
    }

    /**
     * Selects the entities of the context's scope that are relevant to the query embeddings.
     *
     * @param monitor         progress monitor used for metadata loading
     * @param queryEmbeddings query embeddings; each one is matched against the storage independently
     * @param originalContext context supplying the execution context (data source) and the scope object
     * @return the selected entities plus the entities they reference, without duplicates
     *         (by {@link RAGObjectKey}), in the order in which they were first selected
     * @throws DBException propagated from scope flattening, structure caching or the storage lookup
     */
    @Override
    public List<DBSObject> selectObjectsForScope(
        @NotNull DBRProgressMonitor monitor,
        @NotNull List<AIVectorEmbedding> queryEmbeddings,
        @NotNull AIDatabaseContext originalContext
    ) throws DBException {
        DBPDataSource dataSource = originalContext.getExecutionContext().getDataSource();

        // Index every entity of the scope by its RAG key, so that storage records can be resolved
        // back to live objects. The merge function keeps the first occurrence. Without it,
        // toMap would throw IllegalStateException when flattening yields the same key twice.
        Map<RAGObjectKey, DBSEntity> entityKeyMap = RAGUtils.flattenContainers(monitor, List.of(originalContext.getScopeObject()))
            .stream()
            .collect(Collectors.toMap(
                entity -> RAGObjectKey.fromEntity(dataSource, entity),
                entity -> entity,
                (first, duplicate) -> first));
        originalContext.getScopeObject().cacheStructure(monitor, SCOPE_CACHE_FLAGS);

        // The prefix filter key is the same for every embedding, so compute it once.
        RAGObjectKey scopeKey = RAGObjectKey.fromEntity(dataSource, originalContext.getScopeObject());

        // A LinkedHashMap keeps selection order. A HashMap would return the entities in arbitrary order.
        Map<RAGObjectKey, DBSEntity> selected = new LinkedHashMap<>();
        for (AIVectorEmbedding embedding : queryEmbeddings) {
            List<? extends RAGEmbeddedRecord> relevantRecords = embeddingStorage.findRelevantRecords(
                new RAGPrefixRelevantRecordFilter(scopeKey, embedding.vector(), RETRIEVE_COUNT));
            int selectionCount = 0;
            for (RAGEmbeddedRecord record : relevantRecords) {
                if (selectionCount >= MAX_SELECTIONS_PER_EMBEDDING) {
                    break;
                }
                DBSEntity entity = entityKeyMap.get(record.key());
                if (entity == null) {
                    // The record does not correspond to any entity found in the current scope.
                    continue;
                }
                for (DBSEntity candidate : extendWithReferences(monitor, entity)) {
                    selected.putIfAbsent(RAGObjectKey.fromEntity(dataSource, candidate), candidate);
                }
                selectionCount++;
            }
        }
        return new ArrayList<>(selected.values());
    }

    /**
     * Returns the associated entities of every reference reported by
     * {@link DBVUtils#getAllReferences} for {@code entity}, followed by {@code entity} itself.
     * References whose associated entity is {@code null} are skipped.
     *
     * @param monitor progress monitor used when loading references
     * @param entity  entity to expand
     * @return a new mutable list; {@code entity} is always its last element
     */
    private static List<DBSEntity> extendWithReferences(@NotNull DBRProgressMonitor monitor, @NotNull DBSEntity entity) {
        List<DBSEntity> result = DBVUtils.getAllReferences(monitor, entity).stream()
            .map(DBSEntityAssociation::getAssociatedEntity)
            .filter(Objects::nonNull)
            .collect(Collectors.toCollection(ArrayList::new));
        // Appended after its references. The caller de-duplicates with putIfAbsent, so this
        // order decides which instance is kept when two candidates share a key.
        result.add(entity);
        return result;
    }
}

