/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.exactsearch;

import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.index.query.exactsearch.ExactKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

/**
 * {@link ExactKNNIterator} that walks candidate documents and scores each one against the query vector.
 * <p>
 * Candidates come from {@code filterIdsIterator} when it is non-null; otherwise every document exposed by
 * {@link KNNFloatVectorValues#nextDoc()} is a candidate. Scoring uses one of three modes (see
 * {@link #computeScore()}):
 * <ul>
 *   <li>no segment-level quantization info: full-precision similarity for {@code spaceType};</li>
 *   <li>quantization info plus a quantized query: Hamming similarity between quantized vectors;</li>
 *   <li>quantization info but no quantized query: ADC scoring of the float query against the quantized
 *       document vector.</li>
 * </ul>
 * The iterator works one document ahead: {@link #docId} holds the document the next {@link #nextDoc()}
 * call will return.
 */
class VectorIdsExactKNNIterator implements ExactKNNIterator {
    /** Optional iterator restricting which documents are scored; {@code null} means all documents with vectors. */
    protected final DocIdSetIterator filterIdsIterator;
    /** Full-precision query vector. */
    protected final float[] queryVector;
    /** Quantized query vector used for Hamming scoring; {@code null} selects ADC when quantization is active. */
    private final byte[] quantizedQueryVector;
    /** Source of per-document float vectors. */
    protected final KNNFloatVectorValues knnFloatVectorValues;
    /** Space type whose similarity function scores full-precision (and ADC) comparisons. */
    protected final SpaceType spaceType;
    /** Score of the document most recently returned by {@link #nextDoc()}; negative infinity before the first call. */
    protected float currentScore = Float.NEGATIVE_INFINITY;
    /** Look-ahead: the next document {@link #nextDoc()} will return, or {@code NO_MORE_DOCS} once exhausted. */
    protected int docId;
    /** Segment quantization state; {@code null} means vectors are scored at full precision. */
    private final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;

    /**
     * Creates an iterator without quantization.
     *
     * @param filterIdsIterator optional filter of candidate doc ids; {@code null} scores every document with a vector
     * @param queryVector full-precision query vector
     * @param knnFloatVectorValues per-document vector source
     * @param spaceType space type used for scoring
     * @throws IOException if positioning on the first document fails
     */
    public VectorIdsExactKNNIterator(@Nullable DocIdSetIterator filterIdsIterator, float[] queryVector, KNNFloatVectorValues knnFloatVectorValues, SpaceType spaceType) throws IOException {
        this(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, null, null);
    }

    /**
     * Creates an unfiltered iterator without quantization.
     *
     * @param queryVector full-precision query vector
     * @param knnFloatVectorValues per-document vector source
     * @param spaceType space type used for scoring
     * @throws IOException if positioning on the first document fails
     */
    public VectorIdsExactKNNIterator(float[] queryVector, KNNFloatVectorValues knnFloatVectorValues, SpaceType spaceType) throws IOException {
        this(null, queryVector, knnFloatVectorValues, spaceType, null, null);
    }

    /**
     * Creates an iterator, optionally filtered and optionally quantization-aware, and positions it on the
     * first candidate document.
     *
     * @param filterIdsIterator optional filter of candidate doc ids; {@code null} scores every document with a vector
     * @param queryVector full-precision query vector
     * @param knnFloatVectorValues per-document vector source
     * @param spaceType space type used for full-precision and ADC scoring
     * @param quantizedQueryVector quantized query for Hamming scoring; {@code null} selects ADC when
     *        {@code segmentLevelQuantizationInfo} is non-null
     * @param segmentLevelQuantizationInfo segment quantization state; {@code null} disables quantized scoring
     * @throws IOException if positioning on the first document fails
     */
    public VectorIdsExactKNNIterator(@Nullable DocIdSetIterator filterIdsIterator, float[] queryVector, KNNFloatVectorValues knnFloatVectorValues, SpaceType spaceType, @Nullable byte[] quantizedQueryVector, @Nullable SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) throws IOException {
        this.filterIdsIterator = filterIdsIterator;
        this.queryVector = queryVector;
        this.knnFloatVectorValues = knnFloatVectorValues;
        this.spaceType = spaceType;
        this.quantizedQueryVector = quantizedQueryVector;
        this.segmentLevelQuantizationInfo = segmentLevelQuantizationInfo;
        // getNextDocId() is overridable, so call it only after every field is assigned; otherwise an
        // override would observe null final fields.
        this.docId = this.getNextDocId();
    }

    /**
     * Returns the current look-ahead document, scores it, and advances to the following candidate.
     *
     * @return the scored document id, or {@link DocIdSetIterator#NO_MORE_DOCS} when exhausted
     * @throws IOException if reading vectors or advancing fails
     */
    @Override
    public int nextDoc() throws IOException {
        if (this.docId == DocIdSetIterator.NO_MORE_DOCS) {
            return DocIdSetIterator.NO_MORE_DOCS;
        }
        // Score while the vector values are still positioned for the current document, then look ahead.
        this.currentScore = this.computeScore();
        int currentDocId = this.docId;
        this.docId = this.getNextDocId();
        return currentDocId;
    }

    /**
     * Returns the score of the document most recently returned by {@link #nextDoc()}.
     *
     * @return the last computed score, or negative infinity if {@link #nextDoc()} has not scored a document yet
     */
    @Override
    public float score() {
        return this.currentScore;
    }

    /**
     * Scores the vector the vector values are currently positioned on against the query.
     *
     * @return the similarity score: full-precision when no quantization info is set, Hamming when a quantized
     *         query is available, ADC otherwise
     * @throws IOException if the vector cannot be read
     */
    protected float computeScore() throws IOException {
        float[] vector = this.knnFloatVectorValues.getVector();
        if (this.segmentLevelQuantizationInfo == null) {
            return this.spaceType.getKnnVectorSimilarityFunction().compare(this.queryVector, vector);
        }
        byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(vector, this.segmentLevelQuantizationInfo);
        if (this.quantizedQueryVector == null) {
            // No quantized query supplied: compare the float query against the quantized document (ADC).
            return this.scoreWithADC(this.queryVector, quantizedVector, this.spaceType);
        }
        return SpaceType.HAMMING.getKnnVectorSimilarityFunction().compare(this.quantizedQueryVector, quantizedVector);
    }

    /**
     * Advances to the next candidate document and positions the vector values for it.
     *
     * @return the next candidate doc id, or {@link DocIdSetIterator#NO_MORE_DOCS} when exhausted
     * @throws IOException if advancing fails
     */
    protected int getNextDocId() throws IOException {
        if (this.filterIdsIterator == null) {
            return this.knnFloatVectorValues.nextDoc();
        }
        int nextDocID = this.filterIdsIterator.nextDoc();
        if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
            // NOTE(review): assuming advance() follows Lucene's DocIdSetIterator contract (moves to the first doc
            // >= target), a filtered doc without a vector would leave the vector values on a later doc while
            // nextDocID is reported. Confirm that callers only pass filters whose docs all carry vectors.
            this.knnFloatVectorValues.advance(nextDocID);
        }
        return nextDocID;
    }

    /**
     * Reports whether the given quantization configuration requests ADC scoring.
     * Not invoked within this class.
     *
     * @param segmentLevelQuantizationInfo segment quantization state, may be {@code null}
     * @return {@code true} only for scalar quantization params with ADC enabled
     */
    protected boolean shouldScoreWithADC(SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) {
        if (segmentLevelQuantizationInfo == null) {
            return false;
        }
        QuantizationParams quantizationParams = segmentLevelQuantizationInfo.getQuantizationParams();
        if (quantizationParams instanceof ScalarQuantizationParams) {
            ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams;
            return scalarQuantizationParams.isEnableADC();
        }
        return false;
    }

    /**
     * Scores a float query against a quantized document vector using ADC, translating the raw ADC value
     * through the space type's {@code scoreTranslation}.
     *
     * @param queryVector full-precision query vector
     * @param documentVector quantized document vector
     * @param spaceType space type; only L2, INNER_PRODUCT and COSINESIMIL are supported
     * @return the translated score
     * @throws UnsupportedOperationException for any other space type
     */
    protected float scoreWithADC(float[] queryVector, byte[] documentVector, SpaceType spaceType) {
        switch (spaceType) {
            case L2:
                return SpaceType.L2.scoreTranslation(KNNScoringUtil.l2SquaredADC(queryVector, documentVector));
            case INNER_PRODUCT:
                // Presumably negated because scoreTranslation expects a distance-like input — TODO confirm.
                return SpaceType.INNER_PRODUCT.scoreTranslation(-1.0f * KNNScoringUtil.innerProductADC(queryVector, documentVector));
            case COSINESIMIL:
                // Presumably converts the inner-product similarity to a cosine-distance form — TODO confirm.
                return SpaceType.COSINESIMIL.scoreTranslation(1.0f - KNNScoringUtil.innerProductADC(queryVector, documentVector));
            default:
                throw new UnsupportedOperationException("Space type " + spaceType.getValue() + " is not supported for ADC");
        }
    }
}

