/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.HnswIndex;
import com.nvidia.cuvs.HnswIndexParams;
import com.nvidia.cuvs.HnswQuery;
import com.nvidia.cuvs.HnswSearchParams;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.CuVSResourcesImpl;
import com.nvidia.cuvs.internal.HnswSearchResults;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsHnswIndex;
import com.nvidia.cuvs.internal.panama.cuvsHnswIndexParams;
import com.nvidia.cuvs.internal.panama.cuvsHnswSearchParams;
import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.util.Objects;

public class HnswIndexImpl
implements HnswIndex {
    private static final MethodHandle deserializeHnswIndexMethodHandle = LinkerHelper.downcallHandle("deserialize_hnsw_index", FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_INT), new Linker.Option[0]);
    private static final MethodHandle searchHnswIndexMethodHandle = LinkerHelper.downcallHandle("search_hnsw_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_INT, LinkerHelper.C_INT, LinkerHelper.C_LONG), new Linker.Option[0]);
    private static final MethodHandle destroyHnswIndexMethodHandle = LinkerHelper.downcallHandle("destroy_hnsw_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS), new Linker.Option[0]);
    private final CuVSResourcesImpl resources;
    private final HnswIndexParams hnswIndexParams;
    private final IndexReference hnswIndexReference;
    /** Set once {@link #destroyIndex()} has called the native destroy; prevents a second destroy and use after destroy. */
    private boolean destroyed;

    /**
     * Creates the index by deserializing it from {@code inputStream}.
     *
     * @param inputStream     source of the serialized index
     * @param resources       cuVS resources whose arena and handle back this index
     * @param hnswIndexParams index parameters kept for deserialization
     */
    private HnswIndexImpl(InputStream inputStream, CuVSResourcesImpl resources, HnswIndexParams hnswIndexParams) throws Throwable {
        this.hnswIndexParams = hnswIndexParams;
        this.resources = resources;
        this.hnswIndexReference = this.deserialize(inputStream);
    }

    /**
     * Releases the native index via {@code destroy_hnsw_index}.
     *
     * <p>Calling this more than once is a no-op, so the same native handle is never destroyed twice.
     *
     * @throws Throwable if the native call fails or reports an error through its return-value slot
     */
    @Override
    public void destroyIndex() throws Throwable {
        if (this.destroyed) {
            return;
        }
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
            destroyHnswIndexMethodHandle.invokeExact(this.hnswIndexReference.getMemorySegment(), returnValue);
            // The native destroy has run; mark it before error-checking so a failure report
            // cannot lead to the same handle being destroyed again.
            this.destroyed = true;
            Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "destroyHnswIndexMethodHandle");
        }
    }

    /**
     * Runs a k-nearest-neighbour search against this index.
     *
     * <p>The neighbor and distance buffers are allocated in the resources arena, because the returned
     * {@link HnswSearchResults} reads them after this method returns. The query matrix and the
     * search-parameter struct are only needed during the native call. They live in a confined
     * per-call arena, so repeated searches do not grow the shared arena.
     *
     * @param query the query vectors, top-k and search parameters; must not be {@code null}
     * @return the search results
     * @throws NullPointerException     if {@code query} is {@code null}
     * @throws IllegalArgumentException if the query vectors do not all have the same length
     * @throws IllegalStateException    if the index has already been destroyed
     * @throws Throwable                if the native call fails or reports an error
     */
    @Override
    public SearchResults search(HnswQuery query) throws Throwable {
        Objects.requireNonNull(query, "query");
        this.ensureNotDestroyed();

        var queryVectors = query.getQueryVectors();
        int topK = query.getTopK();
        long numQueries = queryVectors.length;
        int vectorDimension = numQueries > 0L ? queryVectors[0].length : 0;
        // A single dimension is handed to native code, so every row must match it.
        for (int row = 1; row < queryVectors.length; row++) {
            if (queryVectors[row].length != vectorDimension) {
                throw new IllegalArgumentException("Query vector " + row + " has length " + queryVectors[row].length
                        + " but expected " + vectorDimension);
            }
        }

        long numBlocks = (long) topK * numQueries;
        SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
        SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
        MemorySegment neighborsMemorySegment = this.resources.getArena().allocate(neighborsSequenceLayout);
        MemorySegment distancesMemorySegment = this.resources.getArena().allocate(distancesSequenceLayout);

        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
            MemorySegment querySeg = Util.buildMemorySegment(localArena, queryVectors);
            MemorySegment searchParamsSeg = this.segmentFromSearchParams(query.getHnswSearchParams(), localArena);
            searchHnswIndexMethodHandle.invokeExact(this.resources.getMemorySegment(), this.hnswIndexReference.getMemorySegment(), searchParamsSeg, returnValue, neighborsMemorySegment, distancesMemorySegment, querySeg, topK, vectorDimension, numQueries);
            Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "searchHnswIndexMethodHandle");
        }
        return new HnswSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, distancesMemorySegment, topK, query.getMapping(), numQueries);
    }

    /** Throws {@link IllegalStateException} if {@link #destroyIndex()} has already run. */
    private void ensureNotDestroyed() {
        if (this.destroyed) {
            throw new IllegalStateException("HNSW index has already been destroyed");
        }
    }

    /** Deserializes the index using the default value 1024 for {@code bufferLength}. */
    private IndexReference deserialize(InputStream inputStream) throws Throwable {
        return this.deserialize(inputStream, 1024);
    }

    /**
     * Deserializes an HNSW index from {@code inputStream}.
     *
     * <p>NOTE(review): CFR 0.152 failed to decompile this method (ConfusedCFRException:
     * "Started 3 blocks at once"), so its original body is lost. This placeholder always throws.
     * Because the constructor calls it, building an {@code HnswIndexImpl} from this source always
     * fails. Restore the body from the original sources before use. It presumably invokes
     * {@code deserializeHnswIndexMethodHandle} with {@code segmentFromIndexParams(...)} — TODO confirm.
     *
     * @param inputStream  source of the serialized index
     * @param bufferLength presumably a copy-buffer size (the one-arg overload passes 1024) — unverified
     * @throws IllegalStateException always, since the body could not be decompiled
     */
    private IndexReference deserialize(InputStream inputStream, int bufferLength) throws Throwable {
        throw new IllegalStateException("Decompilation failed");
    }

    /** Copies {@code ef_construction} and {@code num_threads} into a native struct allocated in the resources arena. */
    private MemorySegment segmentFromIndexParams(HnswIndexParams params2) {
        MemorySegment seg = cuvsHnswIndexParams.allocate(this.resources.getArena());
        cuvsHnswIndexParams.ef_construction(seg, params2.getEfConstruction());
        cuvsHnswIndexParams.num_threads(seg, params2.getNumThreads());
        return seg;
    }

    /**
     * Copies {@code ef} and {@code num_threads} into a native search-params struct.
     *
     * @param params2 the Java-side search parameters
     * @param arena   arena owning the struct; callers pass a short-lived arena because the struct
     *                is only read during the native call
     * @return the populated native struct
     */
    private MemorySegment segmentFromSearchParams(HnswSearchParams params2, Arena arena) {
        MemorySegment seg = cuvsHnswSearchParams.allocate(arena);
        cuvsHnswSearchParams.ef(seg, params2.ef());
        cuvsHnswSearchParams.num_threads(seg, params2.numThreads());
        return seg;
    }

    /**
     * Returns a builder bound to the given resources.
     *
     * @param cuvsResources resources to use; must be a {@link CuVSResourcesImpl}
     * @return a new builder
     * @throws NullPointerException     if {@code cuvsResources} is {@code null}
     * @throws IllegalArgumentException if {@code cuvsResources} is some other implementation
     */
    public static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) {
        Objects.requireNonNull(cuvsResources);
        if (!(cuvsResources instanceof CuVSResourcesImpl resourcesImpl)) {
            throw new IllegalArgumentException("Unsupported " + cuvsResources);
        }
        return new Builder(resourcesImpl);
    }

    /** Holds the native memory segment that identifies a native HNSW index. */
    protected static class IndexReference {
        private final MemorySegment memorySegment;

        /** Allocates a fresh {@code cuvsHnswIndex} struct in the resources arena. */
        protected IndexReference(CuVSResourcesImpl resources) {
            this.memorySegment = cuvsHnswIndex.allocate(resources.getArena());
        }

        /** Wraps an existing native index segment. */
        protected IndexReference(MemorySegment indexMemorySegment) {
            this.memorySegment = indexMemorySegment;
        }

        /** Returns the wrapped native segment. */
        protected MemorySegment getMemorySegment() {
            return this.memorySegment;
        }
    }

    /** Builder that creates an {@link HnswIndexImpl} by deserializing from an input stream. */
    public static class Builder
    implements HnswIndex.Builder {
        private final CuVSResourcesImpl cuvsResources;
        private InputStream inputStream;
        private HnswIndexParams hnswIndexParams;

        /** Creates a builder bound to {@code cuvsResources}. */
        public Builder(CuVSResourcesImpl cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        /** Sets the stream to deserialize the index from (required). */
        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        /** Sets the index parameters used during deserialization. */
        @Override
        public Builder withIndexParams(HnswIndexParams hnswIndexParameters) {
            this.hnswIndexParams = hnswIndexParameters;
            return this;
        }

        /**
         * Builds the index.
         *
         * @throws NullPointerException if {@link #from(InputStream)} was not called with a non-null stream
         * @throws Throwable            if deserialization fails
         */
        @Override
        public HnswIndexImpl build() throws Throwable {
            Objects.requireNonNull(this.inputStream, "inputStream");
            return new HnswIndexImpl(this.inputStream, this.cuvsResources, this.hnswIndexParams);
        }
    }
}

