/*
 * Decompiled with CFR 0.152.
 */
package com.dbeaver.model.ai.rag.storage.pg;

import com.dbeaver.model.ai.RAGDatasourceKey;
import com.dbeaver.model.ai.RAGEmbeddedRecord;
import com.dbeaver.model.ai.RAGObjectKey;
import com.dbeaver.model.ai.RAGPrefixRelevantRecordFilter;
import com.dbeaver.model.ai.RAGRelevantRecordFilter;
import com.dbeaver.model.ai.rag.storage.RAGAbstractEmbeddingStorage;
import com.dbeaver.model.ai.rag.storage.RAGDatabase;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;

/**
 * Embedding storage backed by PostgreSQL with the pgvector extension.
 * <p>
 * Embeddings are stored in {@code ai_object_index} (columns {@code object_key}, {@code embedding},
 * {@code checksum}, {@code model_name}). Per-datasource query timestamps are kept in
 * {@code ai_datasource_stats} (columns {@code datasource_key}, {@code last_query}).
 * Vectors are exchanged with the server as pgvector text literals ({@code [x1,x2,...]}) and cast
 * with {@code ?::vector}.
 * <p>
 * NOTE(review): the {@code {table_prefix}} placeholder in the SQL below is presumably substituted
 * by the connection returned from {@code getConnection()}; that logic is not visible here.
 */
public class RAGPostgresEmbeddingStorage extends RAGAbstractEmbeddingStorage {

    /** Number of rows accumulated in a JDBC batch before it is sent to the server. */
    private static final int BATCH_SIZE = 1000;

    private static final String UPSERT_RECORD_SQL =
        "INSERT INTO {table_prefix}ai_object_index (object_key, embedding, checksum, model_name)\n" +
        "VALUES (?, ?::vector, ?, ?)\n" +
        "ON CONFLICT(object_key) DO UPDATE SET\n" +
        "    embedding=excluded.embedding,\n" +
        "    checksum=excluded.checksum,\n" +
        "    model_name=excluded.model_name\n";

    private static final String FIND_BY_KEYS_SQL =
        "SELECT object_key, checksum, model_name, embedding\n" +
        "FROM {table_prefix}ai_object_index\n" +
        "WHERE object_key = ANY(?)\n" +
        "ORDER BY embedding <-> ?::vector\n" +
        "LIMIT ?\n";

    // LIKE uses PostgreSQL's default escape character (backslash); see escapeLikePattern().
    private static final String FIND_BY_PREFIX_SQL =
        "SELECT object_key, checksum, model_name, embedding\n" +
        "FROM {table_prefix}ai_object_index\n" +
        "WHERE object_key LIKE ?\n" +
        "ORDER BY embedding <-> ?::vector\n" +
        "LIMIT ?\n";

    private static final String FIND_STALE_DATASOURCES_SQL =
        "SELECT datasource_key\n" +
        "FROM {table_prefix}ai_datasource_stats\n" +
        "WHERE last_query < ?\n";

    private static final String UPSERT_DATASOURCE_STATS_SQL =
        "INSERT INTO {table_prefix}ai_datasource_stats (datasource_key, last_query)\n" +
        "   VALUES (?, NOW())\n" +
        "   ON CONFLICT (datasource_key) DO UPDATE\n" +
        "     SET last_query = EXCLUDED.last_query\n";

    public RAGPostgresEmbeddingStorage(@NotNull RAGDatabase ragDatabase) {
        super(ragDatabase);
    }

    /**
     * Inserts or updates ({@code ON CONFLICT(object_key)}) the given records in one transaction.
     * <p>
     * Rows are sent in JDBC batches of {@link #BATCH_SIZE}. On failure the transaction is rolled
     * back. The connection's original auto-commit mode is restored in both the success and
     * failure paths.
     *
     * @param records records to persist; an empty collection is a no-op
     * @throws DBException if any SQL operation fails
     */
    @Override
    public void save(@NotNull Collection<RAGEmbeddedRecord> records) throws DBException {
        if (records.isEmpty()) {
            return;
        }
        try (Connection connection = this.getConnection()) {
            boolean initialAutoCommit = connection.getAutoCommit();
            connection.setAutoCommit(false);
            try (PreparedStatement statement = connection.prepareStatement(UPSERT_RECORD_SQL)) {
                int pending = 0;
                for (RAGEmbeddedRecord record : records) {
                    statement.setString(1, record.key().asString());
                    statement.setString(2, RAGPostgresEmbeddingStorage.toPgVectorLiteral(record.embedding()));
                    statement.setLong(3, record.checksum());
                    statement.setString(4, record.modelName());
                    statement.addBatch();
                    if (++pending == BATCH_SIZE) {
                        statement.executeBatch();
                        pending = 0;
                    }
                }
                if (pending > 0) {
                    statement.executeBatch();
                }
                connection.commit();
            } catch (SQLException | RuntimeException e) {
                // Undo partially executed batches so the connection is not left mid-transaction.
                RAGPostgresEmbeddingStorage.rollbackQuietly(connection, initialAutoCommit, e);
                throw e;
            }
            connection.setAutoCommit(initialAutoCommit);
        } catch (SQLException e) {
            throw new DBException("Error saving RAG embedded records", e);
        }
    }

    /**
     * Returns up to {@code filter.topK()} records whose key is one of {@code filter.objectKeys()}.
     * Results are ordered by pgvector's {@code <->} (Euclidean) distance to {@code filter.query()}.
     *
     * @param filter key set, query embedding and result limit
     * @return matching records, or an empty list when the filter has no object keys
     * @throws DBException if the query fails
     */
    @Override
    @NotNull
    protected List<RAGEmbeddedRecord> findRelevantRecords0(@NotNull RAGRelevantRecordFilter filter) throws DBException {
        if (filter.objectKeys().isEmpty()) {
            return List.of();
        }
        String[] objectKeys = filter.objectKeys().stream()
            .map(RAGObjectKey::asString)
            .toArray(String[]::new);
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(FIND_BY_KEYS_SQL)) {
            statement.setArray(1, connection.createArrayOf("text", objectKeys));
            statement.setString(2, RAGPostgresEmbeddingStorage.toPgVectorLiteral(filter.query()));
            statement.setInt(3, filter.topK());
            try (ResultSet resultSet = statement.executeQuery()) {
                return this.extractRecords(resultSet);
            }
        } catch (SQLException e) {
            throw new DBException("Error finding relevant records", e);
        }
    }

    /**
     * Returns up to {@code filter.topK()} records whose key starts with
     * {@code filter.objectKey().asString() + "/"}. Results are ordered by pgvector's {@code <->}
     * distance to {@code filter.query()}.
     * <p>
     * The key prefix is LIKE-escaped, so {@code _}, {@code %} and {@code \} in keys match
     * literally instead of acting as wildcards.
     *
     * @param filter key prefix, query embedding and result limit
     * @return matching records
     * @throws DBException if the query fails
     */
    @Override
    @NotNull
    protected List<RAGEmbeddedRecord> findRelevantRecordsByPrefix0(@NotNull RAGPrefixRelevantRecordFilter filter) throws DBException {
        String likePattern = RAGPostgresEmbeddingStorage.escapeLikePattern(filter.objectKey().asString()) + "/%";
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(FIND_BY_PREFIX_SQL)) {
            statement.setString(1, likePattern);
            statement.setString(2, RAGPostgresEmbeddingStorage.toPgVectorLiteral(filter.query()));
            statement.setInt(3, filter.topK());
            try (ResultSet resultSet = statement.executeQuery()) {
                return this.extractRecords(resultSet);
            }
        } catch (SQLException e) {
            throw new DBException("Error finding relevant records by prefix", e);
        }
    }

    /**
     * Lists datasources whose {@code last_query} timestamp is earlier than {@code olderThan}.
     *
     * @param olderThan cut-off instant (exclusive)
     * @return keys of stale datasources; possibly empty, never {@code null}
     * @throws DBException if the query fails
     */
    @Override
    @NotNull
    public List<RAGDatasourceKey> findStaleDatasources(@NotNull Instant olderThan) throws DBException {
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(FIND_STALE_DATASOURCES_SQL)) {
            statement.setTimestamp(1, Timestamp.from(olderThan));
            try (ResultSet resultSet = statement.executeQuery()) {
                List<RAGDatasourceKey> results = new ArrayList<>();
                while (resultSet.next()) {
                    results.add(RAGDatasourceKey.fromString(resultSet.getString("datasource_key")));
                }
                return results;
            }
        } catch (SQLException e) {
            throw new DBException("Error querying stale datasources", e);
        }
    }

    /**
     * Sets {@code last_query} of the given datasource to the server's {@code NOW()}, inserting the
     * stats row if it does not exist yet.
     *
     * @param datasourceKey datasource whose stats row is upserted
     * @throws DBException if the statement fails
     */
    @Override
    protected void updateDatasourceStats(@NotNull RAGDatasourceKey datasourceKey) throws DBException {
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(UPSERT_DATASOURCE_STATS_SQL)) {
            statement.setString(1, datasourceKey.asString());
            statement.execute();
        } catch (SQLException e) {
            throw new DBException("Error updating datasource stats", e);
        }
    }

    /**
     * Rolls back the current transaction and restores the auto-commit mode after a failure.
     * Any error raised while doing so is attached to {@code cause} as a suppressed exception,
     * so the original failure is never masked.
     */
    private static void rollbackQuietly(
        @NotNull Connection connection,
        boolean initialAutoCommit,
        @NotNull Throwable cause
    ) {
        try {
            connection.rollback();
        } catch (SQLException rollbackError) {
            cause.addSuppressed(rollbackError);
        }
        try {
            connection.setAutoCommit(initialAutoCommit);
        } catch (SQLException restoreError) {
            cause.addSuppressed(restoreError);
        }
    }

    /**
     * Escapes {@code \}, {@code %} and {@code _} with a backslash so that {@code value} is matched
     * literally by a PostgreSQL LIKE pattern that uses the default escape character.
     */
    @NotNull
    private static String escapeLikePattern(@NotNull String value) {
        StringBuilder sb = new StringBuilder(value.length() + 8);
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            if (c == '\\' || c == '%' || c == '_') {
                sb.append('\\');
            }
            sb.append(c);
        }
        return sb.toString();
    }

    /**
     * Normalizes {@code v} with the inherited {@code normalizeEmbedding} and renders the result as
     * a pgvector text literal such as {@code [0.1,0.2]}.
     *
     * @throws IllegalArgumentException if a normalized component is NaN or infinite
     */
    @NotNull
    private static String toPgVectorLiteral(@NotNull float[] v) {
        float[] normalized = RAGPostgresEmbeddingStorage.normalizeEmbedding(v);
        StringBuilder sb = new StringBuilder(normalized.length * 8 + 2);
        sb.append('[');
        for (int i = 0; i < normalized.length; i++) {
            if (!Float.isFinite(normalized[i])) {
                throw new IllegalArgumentException("Non-finite value at " + i);
            }
            if (i > 0) {
                sb.append(',');
            }
            sb.append(normalized[i]);
        }
        return sb.append(']').toString();
    }

    /**
     * Parses a pgvector text literal ({@code [x1,x2,...]}) into a float array.
     * A literal with an empty body ({@code []}) yields an empty array.
     *
     * @throws IllegalArgumentException if {@code s} is not bracketed or a component is not a number
     */
    @NotNull
    private static float[] fromPgVectorLiteral(@NotNull String s) {
        if (s.length() < 2 || s.charAt(0) != '[' || s.charAt(s.length() - 1) != ']') {
            throw new IllegalArgumentException("Invalid vector literal: " + s);
        }
        String body = s.substring(1, s.length() - 1);
        if (body.trim().isEmpty()) {
            return new float[0];
        }
        String[] parts = body.split(",");
        float[] result = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            // Float.parseFloat tolerates surrounding whitespace.
            result[i] = Float.parseFloat(parts[i]);
        }
        return result;
    }

    /**
     * Builds a record from the current row. The row must contain the {@code object_key},
     * {@code checksum}, {@code model_name} and {@code embedding} columns.
     */
    @Override
    @NotNull
    protected RAGEmbeddedRecord extractRecord(@NotNull ResultSet resultSet) throws SQLException {
        RAGObjectKey key = RAGObjectKey.fromString(resultSet.getString("object_key"));
        long checksum = resultSet.getLong("checksum");
        String modelName = resultSet.getString("model_name");
        float[] embedding = RAGPostgresEmbeddingStorage.fromPgVectorLiteral(resultSet.getString("embedding"));
        return new RAGEmbeddedRecord(key, checksum, modelName, embedding);
    }
}

