/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CagraIndexParams;
import com.nvidia.cuvs.CagraSearchParams;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.TieredIndex;
import com.nvidia.cuvs.TieredIndexParams;
import com.nvidia.cuvs.TieredIndexQuery;
import com.nvidia.cuvs.internal.CuVSMatrixBaseImpl;
import com.nvidia.cuvs.internal.CuVSParamsHelper;
import com.nvidia.cuvs.internal.TieredSearchResultsImpl;
import com.nvidia.cuvs.internal.common.CloseableHandle;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsCagraIndexParams;
import com.nvidia.cuvs.internal.panama.cuvsCagraSearchParams;
import com.nvidia.cuvs.internal.panama.cuvsFilter;
import com.nvidia.cuvs.internal.panama.cuvsTieredIndexParams;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.util.BitSet;
import java.util.Objects;

/**
 * Java wrapper around a native cuVS tiered index, driven through the Panama bindings in
 * {@code headers_h} ({@code cuvsTieredIndexSearch}, {@code cuvsTieredIndexExtend},
 * {@code cuvsTieredIndexDestroy}).
 *
 * <p>Instances are obtained through {@link #newBuilder(CuVSResources)}. After
 * {@link #destroyIndex()} has been called, {@code search}, {@code extend}, executing an
 * extend builder, and a second {@code destroyIndex} all throw {@link IllegalStateException}.
 *
 * <p>NOTE(review): this source was recovered with the CFR decompiler and {@link #build()}
 * could not be decompiled; as written, the dataset-based constructor always throws.
 */
public class TieredIndexImpl
implements TieredIndex {
    /** Dataset the index was created from; closed by {@link #destroyIndex()}. May be null. */
    private final CuVSMatrix dataset;
    /** cuVS resources used for every native call made by this index. */
    private final CuVSResources resources;
    /** Parameters supplied at build time. */
    private final TieredIndexParams tieredIndexParameters;
    /** Handle to the native index, as produced by {@link #build()}. */
    private final IndexReference tieredIndexReference;
    /** Set once {@link #destroyIndex()} has run (even if the native destroy failed). */
    private boolean destroyed;

    /**
     * Builds a new index over {@code dataset} by delegating to {@link #build()}.
     */
    private TieredIndexImpl(TieredIndexParams indexParameters, CuVSMatrix dataset, CuVSResources resources) {
        this.tieredIndexParameters = indexParameters;
        this.dataset = dataset;
        this.resources = resources;
        this.tieredIndexReference = this.build();
        this.destroyed = false;
    }

    /**
     * Deserialization entry point; not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    private TieredIndexImpl(InputStream inputStream, CuVSResources resources) {
        throw new UnsupportedOperationException("Deserialization of TieredIndex is not yet supported");
    }

    /** Throws {@link IllegalStateException} if {@link #destroyIndex()} has already been called. */
    private void checkNotDestroyed() {
        if (this.destroyed) {
            throw new IllegalStateException("destroyed");
        }
    }

    /**
     * Destroys the native index and then closes the dataset, if any.
     *
     * <p>The index is marked destroyed even when the native call reports an error. In that
     * case the error propagates and the dataset is not closed.
     *
     * @throws IllegalStateException if the index was already destroyed
     */
    @Override
    public void destroyIndex() {
        this.checkNotDestroyed();
        try {
            int returnValue = headers_h.cuvsTieredIndexDestroy(this.tieredIndexReference.getMemorySegment());
            Util.checkCuVSError(returnValue, "cuvsTieredIndexDestroy");
            if (this.dataset != null) {
                this.dataset.close();
            }
        }
        finally {
            this.destroyed = true;
        }
    }

    /*
     * NOTE(review): CFR failed to decompile this method ("Started 2 blocks at once"), so the
     * original implementation is unavailable. As recovered it always throws, which makes the
     * dataset-based constructor unusable. Restore it from the original sources.
     */
    private IndexReference build() {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Runs a search on the device and returns the host-side results.
     *
     * <p>Query vectors (and the optional prefilter) are copied to RMM-allocated device memory,
     * {@code cuvsTieredIndexSearch} is invoked, and neighbors/distances are copied back into
     * host segments owned by a confined arena. All device allocations are released before
     * returning, including when a native call fails part-way through.
     *
     * @param query the query; when it carries a mapping, {@code topK} is capped at the
     *              mapping's size
     * @return the search results built by {@code TieredSearchResultsImpl.create}
     * @throws IllegalStateException if the index was destroyed
     * @throws Throwable propagated from the native bindings / error checks
     */
    @Override
    public SearchResults search(TieredIndexQuery query) throws Throwable {
        this.checkNotDestroyed();
        try (Arena localArena = Arena.ofConfined()) {
            int topK = query.getMapping() != null ? Math.min(query.getMapping().size(), query.getTopK()) : query.getTopK();
            long numQueries = query.getQueryVectors().length;
            long numBlocks = (long) topK * numQueries;
            int vectorDimension = numQueries > 0L ? query.getQueryVectors()[0].length : 0;
            SequenceLayout neighborsLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
            SequenceLayout distancesLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
            MemorySegment neighborsSeg = localArena.allocate(neighborsLayout);
            MemorySegment distancesSeg = localArena.allocate(distancesLayout);
            MemorySegment hostQueriesSeg = Util.buildMemorySegment(localArena, query.getQueryVectors());
            try (CuVSResources.ScopedAccess resourceAccess = this.resources.access()) {
                long cuvsRes = resourceAccess.handle();
                long queriesBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long) vectorDimension;
                long neighborsBytes = LinkerHelper.C_LONG_BYTE_SIZE * numQueries * (long) topK;
                long distancesBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long) topK;
                // Device pointers stay NULL until allocated and are reset to NULL as soon as they are
                // freed, so the finally block only releases what is still outstanding after a failure.
                MemorySegment queriesDP = MemorySegment.NULL;
                MemorySegment neighborsDP = MemorySegment.NULL;
                MemorySegment distancesDP = MemorySegment.NULL;
                MemorySegment prefilterDP = MemorySegment.NULL;
                long prefilterBytes = 0L;
                try {
                    queriesDP = Util.allocateRMMSegment(cuvsRes, queriesBytes);
                    neighborsDP = Util.allocateRMMSegment(cuvsRes, neighborsBytes);
                    distancesDP = Util.allocateRMMSegment(cuvsRes, distancesBytes);
                    int returnValue = headers_h.cudaMemcpy(queriesDP, hostQueriesSeg, queriesBytes, headers_h.cudaMemcpyHostToDevice());
                    Util.checkCudaError(returnValue, "cudaMemcpy");
                    long[] queriesShape = {numQueries, vectorDimension};
                    MemorySegment queriesTensor = Util.prepareTensor(localArena, queriesDP, queriesShape, headers_h.kDLFloat(), 32, headers_h.kDLCUDA(), 1);
                    long[] neighborsShape = {numQueries, topK};
                    MemorySegment neighborsTensor = Util.prepareTensor(localArena, neighborsDP, neighborsShape, headers_h.kDLInt(), 64, headers_h.kDLCUDA(), 1);
                    long[] distancesShape = {numQueries, topK};
                    MemorySegment distancesTensor = Util.prepareTensor(localArena, distancesDP, distancesShape, headers_h.kDLFloat(), 32, headers_h.kDLCUDA(), 1);
                    returnValue = headers_h.cuvsStreamSync(cuvsRes);
                    Util.checkCuVSError(returnValue, "cuvsStreamSync");

                    MemorySegment prefilter = cuvsFilter.allocate(localArena);
                    if (query.getPrefilter() != null) {
                        BitSet[] prefilters = {query.getPrefilter()};
                        BitSet concatenatedFilters = Util.concatenate(prefilters, (int) query.getNumDocs());
                        long prefilterDataLength = query.getNumDocs() * (long) prefilters.length;
                        // One 32-bit word per 32 documents, rounded up.
                        long[] prefilterShape = {(prefilterDataLength + 31L) / 32L};
                        prefilterBytes = LinkerHelper.C_INT_BYTE_SIZE * prefilterShape[0];
                        // The host buffer must cover every byte handed to cudaMemcpy below.
                        long[] filters = padToBytes(concatenatedFilters.toLongArray(), prefilterBytes);
                        MemorySegment hostPrefilterSeg = Util.buildMemorySegment(localArena, filters);
                        prefilterDP = Util.allocateRMMSegment(cuvsRes, prefilterBytes);
                        Util.checkCudaError(headers_h.cudaMemcpy(prefilterDP, hostPrefilterSeg, prefilterBytes, headers_h.cudaMemcpyHostToDevice()), "cudaMemcpy");
                        MemorySegment prefilterTensor = Util.prepareTensor(localArena, prefilterDP, prefilterShape, headers_h.kDLUInt(), 32, headers_h.kDLCUDA(), 1);
                        // NOTE(review): 1 presumably selects a bitset filter in cuvsFilterType — confirm.
                        cuvsFilter.type(prefilter, 1);
                        cuvsFilter.addr(prefilter, prefilterTensor.address());
                    } else {
                        // NOTE(review): 0 presumably means "no filter" in cuvsFilterType — confirm.
                        cuvsFilter.type(prefilter, 0);
                        cuvsFilter.addr(prefilter, 0L);
                    }

                    returnValue = headers_h.cuvsTieredIndexSearch(cuvsRes, this.segmentFromSearchParams(query.getCagraSearchParameters(), localArena), this.tieredIndexReference.getMemorySegment(), queriesTensor, neighborsTensor, distancesTensor, prefilter);
                    Util.checkCuVSError(returnValue, "cuvsTieredIndexSearch");
                    returnValue = headers_h.cudaMemcpy(neighborsSeg, neighborsDP, neighborsBytes, headers_h.cudaMemcpyDeviceToHost());
                    Util.checkCudaError(returnValue, "cudaMemcpy");
                    returnValue = headers_h.cudaMemcpy(distancesSeg, distancesDP, distancesBytes, headers_h.cudaMemcpyDeviceToHost());
                    Util.checkCudaError(returnValue, "cudaMemcpy");

                    // Success path: free with error checking. Each pointer is cleared before its
                    // result is checked so a failing free is never retried (no double free).
                    returnValue = headers_h.cuvsRMMFree(cuvsRes, queriesDP, queriesBytes);
                    queriesDP = MemorySegment.NULL;
                    Util.checkCuVSError(returnValue, "cuvsRMMFree");
                    returnValue = headers_h.cuvsRMMFree(cuvsRes, neighborsDP, neighborsBytes);
                    neighborsDP = MemorySegment.NULL;
                    Util.checkCuVSError(returnValue, "cuvsRMMFree");
                    returnValue = headers_h.cuvsRMMFree(cuvsRes, distancesDP, distancesBytes);
                    distancesDP = MemorySegment.NULL;
                    Util.checkCuVSError(returnValue, "cuvsRMMFree");
                    if (prefilterDP.address() != 0L) {
                        returnValue = headers_h.cuvsRMMFree(cuvsRes, prefilterDP, prefilterBytes);
                        prefilterDP = MemorySegment.NULL;
                        Util.checkCuVSError(returnValue, "cuvsRMMFree");
                    }
                } finally {
                    // Failure path only: release whatever was allocated but not yet freed.
                    freeRmmQuietly(cuvsRes, queriesDP, queriesBytes);
                    freeRmmQuietly(cuvsRes, neighborsDP, neighborsBytes);
                    freeRmmQuietly(cuvsRes, distancesDP, distancesBytes);
                    freeRmmQuietly(cuvsRes, prefilterDP, prefilterBytes);
                }
                return TieredSearchResultsImpl.create(neighborsLayout, distancesLayout, neighborsSeg, distancesSeg, topK, query.getMapping(), numQueries);
            }
        }
    }

    /**
     * Returns a builder for adding vectors to this index.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public ExtendBuilder extend() {
        this.checkNotDestroyed();
        return new ExtendBuilder(this);
    }

    /**
     * Copies {@code extendDataset} to device memory and calls {@code cuvsTieredIndexExtend}.
     *
     * <p>The dataset is sent as a float32 2-D tensor of shape (size, columns). It must be a
     * {@code CuVSMatrixBaseImpl}, because its host segment is read directly. The temporary
     * device buffer is released even if the copy or the extend call fails.
     *
     * @param extendDataset non-null dataset to append
     * @throws IllegalStateException if the index was destroyed
     */
    private void performExtend(CuVSMatrix extendDataset) {
        // An ExtendBuilder can outlive destroyIndex(), so the check in extend() is not enough.
        this.checkNotDestroyed();
        assert (extendDataset != null);
        try (Arena localArena = Arena.ofConfined()) {
            long rows = extendDataset.size();
            long cols = extendDataset.columns();
            MemorySegment hostDataSeg = ((CuVSMatrixBaseImpl) extendDataset).memorySegment();
            try (CuVSResources.ScopedAccess resourceAccess = this.resources.access()) {
                long cuvsRes = resourceAccess.handle();
                long dataSize = LinkerHelper.C_FLOAT_BYTE_SIZE * rows * cols;
                MemorySegment datasetDP = Util.allocateRMMSegment(cuvsRes, dataSize);
                try {
                    Util.checkCudaError(headers_h.cudaMemcpy(datasetDP, hostDataSeg, dataSize, headers_h.cudaMemcpyHostToDevice()), "cudaMemcpy");
                    long[] datasetShape = {rows, cols};
                    MemorySegment datasetTensor = Util.prepareTensor(localArena, datasetDP, datasetShape, headers_h.kDLFloat(), 32, headers_h.kDLCUDA(), 1);
                    Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsRes), "cuvsStreamSync");
                    Util.checkCuVSError(headers_h.cuvsTieredIndexExtend(cuvsRes, datasetTensor, this.tieredIndexReference.getMemorySegment()), "cuvsTieredIndexExtend");
                    int freeResult = headers_h.cuvsRMMFree(cuvsRes, datasetDP, dataSize);
                    // Cleared before checking so the finally block never frees it a second time.
                    datasetDP = MemorySegment.NULL;
                    Util.checkCuVSError(freeResult, "cuvsRMMFree");
                } finally {
                    freeRmmQuietly(cuvsRes, datasetDP, dataSize);
                }
            }
        }
    }

    /**
     * Best-effort release of an RMM device allocation on an error path.
     *
     * <p>Does nothing for a NULL (address 0) segment. The native return code is deliberately
     * ignored so that a failing free cannot replace the exception that is already propagating.
     */
    private static void freeRmmQuietly(long cuvsRes, MemorySegment devicePtr, long bytes) {
        if (devicePtr.address() != 0L) {
            headers_h.cuvsRMMFree(cuvsRes, devicePtr, bytes);
        }
    }

    /**
     * Returns {@code words}, zero-extended if necessary so that it spans at least
     * {@code byteCount} bytes.
     *
     * <p>{@link BitSet#toLongArray()} stops at the highest set bit. Without padding, the host
     * prefilter buffer can be shorter than the number of bytes later copied to the device.
     */
    private static long[] padToBytes(long[] words, long byteCount) {
        long requiredWords = (byteCount + Long.BYTES - 1) / Long.BYTES;
        if (words.length >= requiredWords) {
            return words;
        }
        long[] padded = new long[Math.toIntExact(requiredWords)];
        System.arraycopy(words, 0, padded, 0, words.length);
        return padded;
    }

    /**
     * Creates native tiered-index parameters from {@code params}.
     *
     * <p>The metric comes from the CAGRA parameters when present, otherwise from
     * {@code params.getMetric()} (L2 maps to 0, INNER_PRODUCT to 1). When CAGRA parameters are
     * given, they are written into a struct allocated in {@code arena}, and the tiered params
     * point at it. NOTE(review): the arena therefore presumably has to stay open while the
     * returned params are in use, and the caller presumably owns and closes the handle — confirm.
     * IVF-Flat and IVF-PQ params are set to NULL.
     *
     * @throws IllegalArgumentException if no CAGRA params are given and the metric is unsupported
     */
    private static CloseableHandle segmentFromIndexParams(Arena arena, TieredIndexParams params) {
        int metric;
        CloseableHandle paramsHandle = CuVSParamsHelper.createTieredIndexParams();
        MemorySegment seg = paramsHandle.handle();
        if (params.getCagraParams() != null) {
            metric = params.getCagraParams().getCuvsDistanceType().value;
        } else {
            metric = switch (params.getMetric()) {
                case TieredIndexParams.Metric.L2 -> 0;
                case TieredIndexParams.Metric.INNER_PRODUCT -> 1;
                default -> throw new IllegalArgumentException("Unsupported metric: " + params.getMetric());
            };
        }
        cuvsTieredIndexParams.metric(seg, metric);
        // NOTE(review): algo 0 presumably selects CAGRA, matching getIndexType() — confirm.
        int algo = 0;
        cuvsTieredIndexParams.algo(seg, algo);
        cuvsTieredIndexParams.min_ann_rows(seg, params.getMinAnnRows());
        cuvsTieredIndexParams.create_ann_index_on_extend(seg, params.isCreateAnnIndexOnExtend());
        CagraIndexParams cagraParams = params.getCagraParams();
        if (cagraParams != null) {
            MemorySegment cagraParamsSeg = cuvsCagraIndexParams.allocate(arena);
            cuvsCagraIndexParams.intermediate_graph_degree(cagraParamsSeg, cagraParams.getIntermediateGraphDegree());
            cuvsCagraIndexParams.graph_degree(cagraParamsSeg, cagraParams.getGraphDegree());
            cuvsCagraIndexParams.build_algo(cagraParamsSeg, cagraParams.getCagraGraphBuildAlgo().value);
            cuvsCagraIndexParams.nn_descent_niter(cagraParamsSeg, cagraParams.getNNDescentNumIterations());
            cuvsCagraIndexParams.metric(cagraParamsSeg, metric);
            cuvsTieredIndexParams.cagra_params(seg, cagraParamsSeg);
        }
        cuvsTieredIndexParams.ivf_flat_params(seg, MemorySegment.NULL);
        cuvsTieredIndexParams.ivf_pq_params(seg, MemorySegment.NULL);
        return paramsHandle;
    }

    /**
     * Allocates a native {@code cuvsCagraSearchParams} struct in {@code arena} and copies
     * {@code params} into it. The search algorithm and hashmap mode are written only when
     * non-null; otherwise the struct keeps whatever {@code allocate} left in those fields.
     */
    private MemorySegment segmentFromSearchParams(CagraSearchParams params, Arena arena) {
        MemorySegment seg = cuvsCagraSearchParams.allocate(arena);
        cuvsCagraSearchParams.max_queries(seg, params.getMaxQueries());
        cuvsCagraSearchParams.itopk_size(seg, params.getITopKSize());
        cuvsCagraSearchParams.max_iterations(seg, params.getMaxIterations());
        if (params.getCagraSearchAlgo() != null) {
            cuvsCagraSearchParams.algo(seg, params.getCagraSearchAlgo().value);
        }
        cuvsCagraSearchParams.team_size(seg, params.getTeamSize());
        cuvsCagraSearchParams.search_width(seg, params.getSearchWidth());
        cuvsCagraSearchParams.min_iterations(seg, params.getMinIterations());
        cuvsCagraSearchParams.thread_block_size(seg, params.getThreadBlockSize());
        if (params.getHashMapMode() != null) {
            cuvsCagraSearchParams.hashmap_mode(seg, params.getHashMapMode().value);
        }
        cuvsCagraSearchParams.hashmap_max_fill_rate(seg, params.getHashMapMaxFillRate());
        cuvsCagraSearchParams.num_random_samplings(seg, params.getNumRandomSamplings());
        cuvsCagraSearchParams.rand_xor_mask(seg, params.getRandXORMask());
        return seg;
    }

    /** Returns the resources this index was created with. */
    @Override
    public CuVSResources getCuVSResources() {
        return this.resources;
    }

    /** Always reports {@link TieredIndex.TieredIndexType#CAGRA}. */
    @Override
    public TieredIndex.TieredIndexType getIndexType() {
        return TieredIndex.TieredIndexType.CAGRA;
    }

    /**
     * Creates a builder bound to {@code cuvsResources}.
     *
     * @throws NullPointerException if {@code cuvsResources} is null
     */
    public static TieredIndex.Builder newBuilder(CuVSResources cuvsResources) {
        Objects.requireNonNull(cuvsResources);
        return new Builder(cuvsResources);
    }

    /** Holder for the native index handle segment. */
    public static class IndexReference {
        private final MemorySegment memorySegment;

        protected IndexReference(MemorySegment indexMemorySegment) {
            this.memorySegment = indexMemorySegment;
        }

        protected MemorySegment getMemorySegment() {
            return this.memorySegment;
        }
    }

    /** Collects a dataset and appends it to the owning index on {@link #execute()}. */
    public static class ExtendBuilder
    implements TieredIndex.ExtendBuilder {
        private final TieredIndexImpl index;
        private CuVSMatrix dataset;

        private ExtendBuilder(TieredIndexImpl index) {
            this.index = index;
        }

        /** Wraps {@code vectors} with {@code CuVSMatrix.ofArray}; replaces any earlier dataset. */
        @Override
        public ExtendBuilder withDataset(float[][] vectors) {
            this.dataset = CuVSMatrix.ofArray(vectors);
            return this;
        }

        /** Uses {@code dataset} as-is; replaces any earlier dataset. */
        @Override
        public ExtendBuilder withDataset(CuVSMatrix dataset) {
            this.dataset = dataset;
            return this;
        }

        /**
         * Extends the index with the configured dataset.
         *
         * @throws IllegalArgumentException if no dataset was provided
         * @throws IllegalStateException if the index has been destroyed
         */
        @Override
        public void execute() {
            if (this.dataset == null) {
                throw new IllegalArgumentException("Must provide a dataset");
            }
            this.index.performExtend(this.dataset);
        }
    }

    /** Builder for {@link TieredIndexImpl}. */
    public static class Builder
    implements TieredIndex.Builder {
        private final CuVSResources resources;
        private CuVSMatrix dataset;
        private TieredIndexParams params;
        // NOTE(review): recorded by withIndexType() but not consulted by build() below.
        private TieredIndex.TieredIndexType indexType = TieredIndex.TieredIndexType.CAGRA;
        private InputStream inputStream;

        private Builder(CuVSResources resources) {
            this.resources = resources;
        }

        /** Requests deserialization from {@code inputStream}; takes precedence in {@link #build()}. */
        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        /**
         * Sets the dataset from a non-empty array of vectors.
         *
         * @throws IllegalArgumentException if a dataset was already set or {@code vectors} is
         *                                  null, empty, or has an empty first row
         */
        @Override
        public Builder withDataset(float[][] vectors) {
            if (this.dataset != null) {
                throw new IllegalArgumentException("An input dataset can only be specified once");
            }
            if (vectors == null || vectors.length == 0 || vectors[0].length == 0) {
                throw new IllegalArgumentException("The input vectors cannot be null or empty");
            }
            this.dataset = CuVSMatrix.ofArray(vectors);
            return this;
        }

        /**
         * Sets the dataset.
         *
         * @throws IllegalArgumentException if a dataset was already set or {@code dataset} is null
         */
        @Override
        public Builder withDataset(CuVSMatrix dataset) {
            if (this.dataset != null) {
                throw new IllegalArgumentException("An input dataset can only be specified once");
            }
            if (dataset == null) {
                throw new IllegalArgumentException("An input dataset cannot be null");
            }
            this.dataset = dataset;
            return this;
        }

        @Override
        public Builder withIndexParams(TieredIndexParams params) {
            this.params = params;
            return this;
        }

        @Override
        public Builder withIndexType(TieredIndex.TieredIndexType indexType) {
            this.indexType = indexType;
            return this;
        }

        /**
         * Builds the index, either by deserializing (when {@link #from} was called) or from
         * the configured dataset and parameters.
         *
         * @throws IllegalArgumentException if no dataset was provided
         * @throws IllegalStateException if no index parameters were provided
         */
        @Override
        public TieredIndex build() throws Throwable {
            if (this.inputStream != null) {
                return new TieredIndexImpl(this.inputStream, this.resources);
            }
            if (this.dataset == null) {
                throw new IllegalArgumentException("Must provide a dataset");
            }
            if (this.params == null) {
                throw new IllegalStateException("Index parameters must be provided");
            }
            return new TieredIndexImpl(this.params, this.dataset, this.resources);
        }
    }
}

