/*
 * Decompiled with CFR 0.152.
 */
package com.azure.cosmos.implementation.query;

import com.azure.cosmos.BridgeInternal;
import com.azure.cosmos.CosmosDiagnostics;
import com.azure.cosmos.CosmosException;
import com.azure.cosmos.implementation.ClientSideRequestStatistics;
import com.azure.cosmos.implementation.DiagnosticsClientContext;
import com.azure.cosmos.implementation.Document;
import com.azure.cosmos.implementation.DocumentClientRetryPolicy;
import com.azure.cosmos.implementation.DocumentCollection;
import com.azure.cosmos.implementation.ImplementationBridgeHelpers;
import com.azure.cosmos.implementation.QueryMetrics;
import com.azure.cosmos.implementation.RequestChargeTracker;
import com.azure.cosmos.implementation.ResourceType;
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
import com.azure.cosmos.implementation.Utils;
import com.azure.cosmos.implementation.feedranges.FeedRangeEpkImpl;
import com.azure.cosmos.implementation.query.DocumentProducer;
import com.azure.cosmos.implementation.query.HybridSearchDocumentProducer;
import com.azure.cosmos.implementation.query.IDocumentQueryClient;
import com.azure.cosmos.implementation.query.IDocumentQueryExecutionComponent;
import com.azure.cosmos.implementation.query.ParallelDocumentQueryExecutionContextBase;
import com.azure.cosmos.implementation.query.PipelinedDocumentQueryParams;
import com.azure.cosmos.implementation.query.QueryInfo;
import com.azure.cosmos.implementation.query.TriFunction;
import com.azure.cosmos.implementation.query.hybridsearch.FullTextQueryStatistics;
import com.azure.cosmos.implementation.query.hybridsearch.GlobalFullTextSearchQueryStatistics;
import com.azure.cosmos.implementation.query.hybridsearch.HybridSearchQueryInfo;
import com.azure.cosmos.implementation.query.hybridsearch.HybridSearchQueryResult;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
import com.azure.cosmos.models.FeedResponse;
import com.azure.cosmos.models.ModelBridgeInternal;
import com.azure.cosmos.models.SqlQuerySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class HybridSearchDocumentQueryExecutionContext
extends ParallelDocumentQueryExecutionContextBase<Document> {
    private static final ImplementationBridgeHelpers.CosmosDiagnosticsHelper.CosmosDiagnosticsAccessor diagnosticsAccessor = ImplementationBridgeHelpers.CosmosDiagnosticsHelper.getCosmosDiagnosticsAccessor();
    private static final ImplementationBridgeHelpers.FeedResponseHelper.FeedResponseAccessor feedResponseAccessor = ImplementationBridgeHelpers.FeedResponseHelper.getFeedResponseAccessor();
    private static final String FORMATTABLE_TOTAL_DOCUMENT_COUNT = "{documentdb-formattablehybridsearchquery-totaldocumentcount}";
    private static final String FORMATTABLE_TOTAL_WORD_COUNT = "{documentdb-formattablehybridsearchquery-totalwordcount-%d}";
    private static final String FORMATTABLE_HIT_COUNTS_ARRAY = "{documentdb-formattablehybridsearchquery-hitcountsarray-%d}";
    private static final String FORMATTABLE_ORDER_BY = "{documentdb-formattableorderbyquery-filter}";
    private static final String TRUE = "true";
    private static final Integer RRF_CONSTANT = 60;
    protected IDocumentQueryClient client;
    private final HybridSearchQueryInfo hybridSearchQueryInfo;
    private final RequestChargeTracker tracker;
    private final ConcurrentMap<String, QueryMetrics> queryMetricMap;
    private final Collection<ClientSideRequestStatistics> clientSideRequestStatistics;
    private Flux<HybridSearchQueryResult<Document>> hybridObservable;
    private Mono<GlobalFullTextSearchQueryStatistics> aggregatedGlobalStatistics;

    protected HybridSearchDocumentQueryExecutionContext(DiagnosticsClientContext diagnosticsClientContext, IDocumentQueryClient client, ResourceType resourceTypeEnum, SqlQuerySpec query, CosmosQueryRequestOptions cosmosQueryRequestOptions, String resourceLink, String rewrittenQuery, UUID correlatedActivityId, AtomicBoolean isQueryCancelledOnTimeout, HybridSearchQueryInfo hybridSearchQueryInfo) {
        super(diagnosticsClientContext, client, resourceTypeEnum, Document.class, query, cosmosQueryRequestOptions, resourceLink, rewrittenQuery, correlatedActivityId, Boolean.FALSE, isQueryCancelledOnTimeout);
        this.hybridSearchQueryInfo = hybridSearchQueryInfo;
        this.client = client;
        this.tracker = new RequestChargeTracker();
        this.queryMetricMap = new ConcurrentHashMap<String, QueryMetrics>();
        this.clientSideRequestStatistics = ConcurrentHashMap.newKeySet();
    }

    public static Flux<IDocumentQueryExecutionComponent<Document>> createAsync(DiagnosticsClientContext diagnosticsClientContext, IDocumentQueryClient client, PipelinedDocumentQueryParams<Document> initParams, DocumentCollection collection) {
        HybridSearchQueryInfo hybridSearchQueryInfo = initParams.getHybridSearchQueryInfo();
        HybridSearchDocumentQueryExecutionContext context = new HybridSearchDocumentQueryExecutionContext(diagnosticsClientContext, client, initParams.getResourceTypeEnum(), initParams.getQuery(), initParams.getCosmosQueryRequestOptions(), initParams.getResourceLink(), null, initParams.getCorrelatedActivityId(), initParams.isQueryCancelledOnTimeout(), hybridSearchQueryInfo);
        context.setTop(initParams.getTop());
        try {
            context.initialize(initParams.getFeedRanges(), initParams.getAllFeedRanges(), initParams.getInitialPageSize(), collection);
            return Flux.just((Object)context);
        }
        catch (CosmosException dce) {
            return Flux.error((Throwable)((Object)dce));
        }
    }

    private void initialize(List<FeedRangeEpkImpl> targetFeedRanges, List<FeedRangeEpkImpl> allFeedRanges, int initialPageSize, DocumentCollection collection) {
        if (this.hybridSearchQueryInfo.getRequiresGlobalStatistics().booleanValue()) {
            HashMap<FeedRangeEpkImpl, String> partitionKeyRangeToContinuationToken = new HashMap<FeedRangeEpkImpl, String>();
            for (FeedRangeEpkImpl feedRangeEpk : allFeedRanges) {
                partitionKeyRangeToContinuationToken.put(feedRangeEpk, null);
            }
            super.initialize(collection, partitionKeyRangeToContinuationToken, initialPageSize, new SqlQuerySpec(this.hybridSearchQueryInfo.getGlobalStatisticsQuery()));
            this.aggregatedGlobalStatistics = Flux.fromIterable((Iterable)this.documentProducers).flatMap(producer -> producer.produceAsync().map(documentProducerFeedResponse -> {
                List results = documentProducerFeedResponse.pageResult.getResults();
                return new GlobalFullTextSearchQueryStatistics((Document)results.get(0));
            })).collectList().map(this::aggregateStatistics);
        }
        this.hybridObservable = this.hybridSearch(targetFeedRanges, initialPageSize, collection);
    }

    private Flux<HybridSearchQueryResult<Document>> hybridSearch(List<FeedRangeEpkImpl> targetFeedRanges, int initialPageSize, DocumentCollection collection) {
        Flux<QueryInfo> rewrittenQueryInfos = this.retrieveRewrittenQueryInfos(this.hybridSearchQueryInfo.getComponentQueryInfoList());
        Flux<Document> componentQueryResults = this.getComponentQueryResults(targetFeedRanges, initialPageSize, collection, rewrittenQueryInfos);
        Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults = this.coalesceAndSortResults(componentQueryResults);
        Mono<List<List<ScoreTuple>>> componentScoresList = HybridSearchDocumentQueryExecutionContext.retrieveComponentScores(coalescedAndSortedResults);
        Mono<List<List<Integer>>> ranks = HybridSearchDocumentQueryExecutionContext.computeRanks(componentScoresList);
        return HybridSearchDocumentQueryExecutionContext.computeRRFScores(ranks, coalescedAndSortedResults);
    }

    protected HybridSearchDocumentProducer createDocumentProducer(String collectionRid, String continuationToken, int initialPageSize, CosmosQueryRequestOptions cosmosQueryRequestOptions, SqlQuerySpec querySpecForInit, Map<String, String> commonRequestHeaders, TriFunction<FeedRangeEpkImpl, String, Integer, RxDocumentServiceRequest> createRequestFunc, Function<RxDocumentServiceRequest, Mono<FeedResponse<Document>>> executeFunc, Supplier<DocumentClientRetryPolicy> createRetryPolicyFunc, FeedRangeEpkImpl feedRange, String collectionLink) {
        return new HybridSearchDocumentProducer(this.client, collectionRid, cosmosQueryRequestOptions, createRequestFunc, executeFunc, feedRange, collectionLink, createRetryPolicyFunc, Document.class, this.correlatedActivityId, initialPageSize, continuationToken, this.top, this.getOperationContextTextProvider());
    }

    @Override
    public Flux<FeedResponse<Document>> drainAsync(int maxPageSize) {
        return this.hybridObservable.transformDeferred((Function)new HybridSearchQueryResultToPageTransformer(this.tracker, maxPageSize, this.queryMetricMap, this.clientSideRequestStatistics));
    }

    @Override
    public Flux<FeedResponse<Document>> executeAsync() {
        return this.drainAsync(ModelBridgeInternal.getMaxItemCountFromQueryRequestOptions(this.cosmosQueryRequestOptions));
    }

    private static Flux<HybridSearchQueryResult<Document>> computeRRFScores(Mono<List<List<Integer>>> ranks, Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults) {
        return ranks.zipWith(coalescedAndSortedResults).map(tuple -> {
            List ranksInternal = (List)tuple.getT1();
            List results = (List)tuple.getT2();
            for (int index = 0; index < results.size(); ++index) {
                double rrfScore = 0.0;
                for (List integers : ranksInternal) {
                    rrfScore += 1.0 / (double)(RRF_CONSTANT + (Integer)integers.get(index));
                }
                ((HybridSearchQueryResult)results.get(index)).setScore(rrfScore);
            }
            results.sort(Comparator.comparing(HybridSearchQueryResult::getScore, Comparator.reverseOrder()));
            return results;
        }).flatMapMany(Flux::fromIterable);
    }

    private static Mono<List<List<Integer>>> computeRanks(Mono<List<List<ScoreTuple>>> componentScoresList) {
        return componentScoresList.map(componentScores -> {
            int index;
            int componentIndex;
            ArrayList ranksInternal = new ArrayList();
            for (componentIndex = 0; componentIndex < componentScores.size(); ++componentIndex) {
                ArrayList<Integer> row = new ArrayList<Integer>();
                for (index = 0; index < ((List)componentScores.get(0)).size(); ++index) {
                    row.add(0);
                }
                ranksInternal.add(row);
            }
            for (componentIndex = 0; componentIndex < componentScores.size(); ++componentIndex) {
                int rank = 1;
                for (index = 0; index < ((List)componentScores.get(componentIndex)).size(); ++index) {
                    if (index > 0 && ((ScoreTuple)((List)componentScores.get(componentIndex)).get(index)).getScore() < ((ScoreTuple)((List)componentScores.get(componentIndex)).get(index - 1)).getScore()) {
                        ++rank;
                    }
                    int rankIndex = ((ScoreTuple)((List)componentScores.get(componentIndex)).get(index)).getIndex();
                    ((List)ranksInternal.get(componentIndex)).set(rankIndex, rank);
                }
            }
            return ranksInternal;
        });
    }

    private static Mono<List<List<ScoreTuple>>> retrieveComponentScores(Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults) {
        return coalescedAndSortedResults.map(results -> {
            int i;
            ArrayList componentScoresInternal = new ArrayList();
            for (int i2 = 0; i2 < ((HybridSearchQueryResult)results.get(0)).getComponentScores().size(); ++i2) {
                componentScoresInternal.add(new ArrayList());
            }
            ArrayList<Double> undefinedComponentScores = new ArrayList<Double>();
            for (i = 0; i < componentScoresInternal.size(); ++i) {
                undefinedComponentScores.add(-999999.0);
            }
            for (i = 0; i < results.size(); ++i) {
                void var4_6;
                List<Double> list = ((HybridSearchQueryResult)results.get(i)).getComponentScores();
                if (list.isEmpty()) {
                    ArrayList<Double> arrayList = undefinedComponentScores;
                }
                for (int j = 0; j < var4_6.size(); ++j) {
                    ScoreTuple scoreTuple = new ScoreTuple((Double)var4_6.get(j), i);
                    ((List)componentScoresInternal.get(j)).add(scoreTuple);
                }
            }
            for (List list : componentScoresInternal) {
                list.sort(Comparator.comparing(ScoreTuple::getScore, Comparator.reverseOrder()));
            }
            return componentScoresInternal;
        });
    }

    private Flux<Document> getComponentQueryResults(List<FeedRangeEpkImpl> targetFeedRanges, int initialPageSize, DocumentCollection collection, Flux<QueryInfo> rewrittenQueryInfos) {
        return rewrittenQueryInfos.flatMap(queryInfo -> {
            HashMap<FeedRangeEpkImpl, String> partitionKeyRangeToContinuationToken = new HashMap<FeedRangeEpkImpl, String>();
            for (FeedRangeEpkImpl feedRangeEpk : targetFeedRanges) {
                partitionKeyRangeToContinuationToken.put(feedRangeEpk, null);
            }
            this.documentProducers = new ArrayList();
            super.initialize(collection, partitionKeyRangeToContinuationToken, initialPageSize, new SqlQuerySpec(queryInfo.getRewrittenQuery()));
            return Flux.fromIterable((Iterable)this.documentProducers).flatMap(DocumentProducer::produceAsync).flatMap(response -> Flux.fromIterable(response.pageResult.getResults()));
        });
    }

    private Flux<QueryInfo> retrieveRewrittenQueryInfos(List<QueryInfo> componentQueryInfos) {
        return this.aggregatedGlobalStatistics.hasElement().flatMapMany(globalStatistics -> {
            if (globalStatistics != null) {
                ArrayList<Mono> rewrittenQueryInfosInternal = new ArrayList<Mono>();
                for (QueryInfo queryInfo : componentQueryInfos) {
                    assert (queryInfo.hasOrderBy());
                    assert (queryInfo.hasNonStreamingOrderBy());
                    List rewrittenOrderByExpressionList = queryInfo.getOrderByExpressions().stream().map(orderByExpression -> this.formatComponentQuery((String)orderByExpression, componentQueryInfos.size())).collect(Collectors.toList());
                    Mono rewrittenOrderByExpression = Mono.zip(rewrittenOrderByExpressionList, results -> Arrays.stream(results).map(Object::toString).collect(Collectors.toList()));
                    Mono<String> rewrittenQuery = this.formatComponentQuery(queryInfo.getRewrittenQuery(), componentQueryInfos.size());
                    Mono newQueryInfo = Mono.zip((Mono)rewrittenOrderByExpression, rewrittenQuery).map(tuple -> {
                        QueryInfo newQueryInfoInternal = new QueryInfo(queryInfo.getPropertyBag());
                        newQueryInfoInternal.setOrderByExpressions((List)tuple.getT1());
                        newQueryInfoInternal.setRewrittenQuery((String)tuple.getT2());
                        return newQueryInfoInternal;
                    });
                    rewrittenQueryInfosInternal.add(newQueryInfo);
                }
                return Flux.concat(rewrittenQueryInfosInternal);
            }
            return Flux.fromIterable((Iterable)componentQueryInfos);
        });
    }

    private Mono<List<HybridSearchQueryResult<Document>>> coalesceAndSortResults(Flux<Document> componentQueryResults) {
        return componentQueryResults.collectList().map(results -> {
            LinkedHashMap uniqueDocuments = new LinkedHashMap();
            for (Document document : results) {
                HybridSearchQueryResult result = new HybridSearchQueryResult(document.toJson());
                String rid = result.getRid();
                uniqueDocuments.putIfAbsent(rid, result);
            }
            ArrayList coalescedResults = new ArrayList(uniqueDocuments.values());
            coalescedResults.sort(Comparator.comparing(HybridSearchQueryResult::getRid));
            return coalescedResults;
        });
    }

    private Mono<String> formatComponentQuery(String orderByExpression, int componentCount) {
        return this.aggregatedGlobalStatistics.map(statistics -> {
            String query = orderByExpression.replace(FORMATTABLE_ORDER_BY, TRUE).replace(FORMATTABLE_TOTAL_DOCUMENT_COUNT, Long.toString(statistics.getDocumentCount()));
            int statisticsIndex = 0;
            for (int componentIndex = 0; componentIndex < componentCount; ++componentIndex) {
                String totalWordCountPlaceHolder = String.format(FORMATTABLE_TOTAL_WORD_COUNT, componentIndex);
                String hitCountsArrayPlaceHolder = String.format(FORMATTABLE_HIT_COUNTS_ARRAY, componentIndex);
                if (!query.contains(totalWordCountPlaceHolder)) continue;
                FullTextQueryStatistics fullTextQueryStatistics = statistics.getFullTextQueryStatistics().get(statisticsIndex);
                query = query.replace(totalWordCountPlaceHolder, Long.toString(fullTextQueryStatistics.getTotalWordCount()));
                String hit_counts_array = "[" + fullTextQueryStatistics.getHitCounts().stream().map(Object::toString).collect(Collectors.joining(",")) + "]";
                query = query.replace(hitCountsArrayPlaceHolder, hit_counts_array);
                ++statisticsIndex;
            }
            return query;
        });
    }

    private GlobalFullTextSearchQueryStatistics aggregateStatistics(List<GlobalFullTextSearchQueryStatistics> globalFullTextSearchQueryStatistics) {
        GlobalFullTextSearchQueryStatistics aggregatedStats = new GlobalFullTextSearchQueryStatistics();
        aggregatedStats.setDocumentCount(0L);
        ArrayList<FullTextQueryStatistics> aggregateFullTextQueryStatistics = new ArrayList();
        for (GlobalFullTextSearchQueryStatistics statistics : globalFullTextSearchQueryStatistics) {
            aggregatedStats.setDocumentCount(aggregatedStats.getDocumentCount() + statistics.getDocumentCount());
            if (aggregateFullTextQueryStatistics.isEmpty()) {
                aggregateFullTextQueryStatistics = statistics.getFullTextQueryStatistics();
            } else {
                assert (statistics.getFullTextQueryStatistics().size() == aggregateFullTextQueryStatistics.size());
                for (int i = 0; i < statistics.getFullTextQueryStatistics().size(); ++i) {
                    assert (statistics.getFullTextQueryStatistics().get(i).getHitCounts().size() == ((FullTextQueryStatistics)aggregateFullTextQueryStatistics.get(i)).getHitCounts().size());
                    ((FullTextQueryStatistics)aggregateFullTextQueryStatistics.get(i)).setTotalWordCount(((FullTextQueryStatistics)aggregateFullTextQueryStatistics.get(i)).getTotalWordCount() + statistics.getFullTextQueryStatistics().get(i).getTotalWordCount());
                    for (int j = 0; j < statistics.getFullTextQueryStatistics().get(i).getHitCounts().size(); ++j) {
                        ((FullTextQueryStatistics)aggregateFullTextQueryStatistics.get(i)).getHitCounts().set(j, ((FullTextQueryStatistics)aggregateFullTextQueryStatistics.get(i)).getHitCounts().get(j) + statistics.getFullTextQueryStatistics().get(i).getHitCounts().get(j));
                    }
                }
            }
            aggregatedStats.setFullTextQueryStatistics(aggregateFullTextQueryStatistics);
        }
        return aggregatedStats;
    }

    private static class HybridSearchQueryResultToPageTransformer
    implements Function<Flux<HybridSearchQueryResult<Document>>, Flux<FeedResponse<Document>>> {
        private static final int DEFAULT_PAGE_SIZE = 100;
        private final RequestChargeTracker tracker;
        private final int maxPageSize;
        private final ConcurrentMap<String, QueryMetrics> queryMetricMap;
        private final Collection<ClientSideRequestStatistics> clientSideRequestStatistics;

        public HybridSearchQueryResultToPageTransformer(RequestChargeTracker tracker, int maxPageSize, ConcurrentMap<String, QueryMetrics> queryMetricsMap, Collection<ClientSideRequestStatistics> clientSideRequestStatistics) {
            this.tracker = tracker;
            this.maxPageSize = maxPageSize > 0 ? maxPageSize : 100;
            this.queryMetricMap = queryMetricsMap;
            this.clientSideRequestStatistics = clientSideRequestStatistics;
        }

        private static Map<String, String> headerResponse(double requestCharge) {
            return Utils.immutableMapOf("x-ms-request-charge", String.valueOf(requestCharge));
        }

        @Override
        public Flux<FeedResponse<Document>> apply(Flux<HybridSearchQueryResult<Document>> source) {
            return source.window(this.maxPageSize).map(Flux::collectList).flatMap(resultListObs -> resultListObs, 1).map(hybridSearchQueryResults -> {
                FeedResponse feedResponse = feedResponseAccessor.createFeedResponse(hybridSearchQueryResults, HybridSearchQueryResultToPageTransformer.headerResponse(this.tracker.getAndResetCharge()), (CosmosDiagnostics)null);
                if (!this.queryMetricMap.isEmpty()) {
                    for (Map.Entry entry : this.queryMetricMap.entrySet()) {
                        BridgeInternal.putQueryMetricsIntoMap(feedResponse, (String)entry.getKey(), (QueryMetrics)entry.getValue());
                    }
                }
                return feedResponse;
            }).map(feedOfHybridSearchQueryResults -> {
                ArrayList<Document> unwrappedResults = new ArrayList<Document>();
                for (HybridSearchQueryResult hybridSearchQueryResult : feedOfHybridSearchQueryResults.getResults()) {
                    unwrappedResults.add(hybridSearchQueryResult.getPayload());
                }
                FeedResponse feedResponse = BridgeInternal.createFeedResponseWithQueryMetrics(unwrappedResults, feedOfHybridSearchQueryResults.getResponseHeaders(), BridgeInternal.queryMetricsFromFeedResponse(feedOfHybridSearchQueryResults), ModelBridgeInternal.getQueryPlanDiagnosticsContext(feedOfHybridSearchQueryResults), false, false, feedOfHybridSearchQueryResults.getCosmosDiagnostics());
                diagnosticsAccessor.addClientSideDiagnosticsToFeed(feedResponse.getCosmosDiagnostics(), this.clientSideRequestStatistics);
                return feedResponse;
            }).switchIfEmpty((Publisher)Flux.defer(() -> {
                FeedResponse frp = BridgeInternal.createFeedResponseWithQueryMetrics(Utils.immutableListOf(), HybridSearchQueryResultToPageTransformer.headerResponse(this.tracker.getAndResetCharge()), this.queryMetricMap, null, false, false, null);
                diagnosticsAccessor.addClientSideDiagnosticsToFeed(frp.getCosmosDiagnostics(), this.clientSideRequestStatistics);
                return Flux.just(frp);
            }));
        }
    }

    public static class ScoreTuple {
        private final Double score;
        private final Integer index;

        public ScoreTuple(Double score, Integer index) {
            this.score = Objects.requireNonNull(score);
            this.index = Objects.requireNonNull(index);
        }

        public Integer getIndex() {
            return this.index;
        }

        public Double getScore() {
            return this.score;
        }
    }
}

