/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.BruteForceIndexParams;
import com.nvidia.cuvs.BruteForceQuery;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.BruteForceSearchResults;
import com.nvidia.cuvs.internal.CuVSMatrixBaseImpl;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsFilter;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Deque;
import java.util.Objects;
import java.util.UUID;

/**
 * Brute-force index implementation backed by the cuVS C API (via Panama bindings in {@code headers_h}).
 *
 * <p>An instance owns native state: a cuVS index handle and, for indexes built from a dataset, an RMM
 * device copy of that dataset plus the shared {@link Arena} holding its tensor descriptor. All of it is
 * released by {@link #destroyIndex()}; afterwards every operation throws {@link IllegalStateException}.
 */
public class BruteForceIndexImpl implements BruteForceIndex {
    private final CuVSResources resources;
    private final IndexReference bruteForceIndexReference;
    private boolean destroyed;

    /**
     * Builds a new index over {@code dataset}. The dataset is closed when this constructor returns or
     * throws, i.e. ownership of it is transferred here.
     *
     * @throws NullPointerException     if {@code dataset} is null
     * @throws IllegalArgumentException if {@code dataset} is not a {@link CuVSMatrixBaseImpl}
     */
    private BruteForceIndexImpl(CuVSMatrix dataset, CuVSResources resources, BruteForceIndexParams bruteForceIndexParams) throws Exception {
        Objects.requireNonNull(dataset, "dataset");
        this.resources = resources;
        try (CuVSMatrix ownedDataset = dataset) {
            // An explicit check instead of `assert`: without -ea the assert vanished and the cast
            // below surfaced as an unhelpful ClassCastException.
            if (!(ownedDataset instanceof CuVSMatrixBaseImpl datasetImpl)) {
                throw new IllegalArgumentException("Unsupported CuVSMatrix implementation: " + ownedDataset.getClass().getName());
            }
            this.bruteForceIndexReference = this.build(datasetImpl, bruteForceIndexParams);
        }
    }

    /**
     * Loads a previously serialized index from {@code inputStream}; the stream is closed afterwards.
     */
    private BruteForceIndexImpl(InputStream inputStream, CuVSResources resources) throws Throwable {
        this.resources = resources;
        this.bruteForceIndexReference = this.deserialize(inputStream);
    }

    /** @throws IllegalStateException if {@link #destroyIndex()} has already been called */
    private void checkNotDestroyed() {
        if (this.destroyed) {
            throw new IllegalStateException("destroyed");
        }
    }

    /**
     * Releases the native index, the device copy of the dataset (if any) and the tensor arena (if any).
     *
     * <p>Every release step is attempted even if an earlier one fails; the first failure is rethrown
     * with later ones attached as suppressed exceptions. The index is marked destroyed regardless.
     *
     * @throws IllegalStateException if the index was already destroyed
     */
    @Override
    public void destroyIndex() {
        this.checkNotDestroyed();
        IndexReference ref = this.bruteForceIndexReference;
        try {
            RuntimeException failure = runCollectingFailure(null,
                    () -> Util.checkCuVSError(headers_h.cuvsBruteForceIndexDestroy(ref.indexPtr), "cuvsBruteForceIndexDestroy"));
            if (ref.datasetBytes > 0L) {
                failure = runCollectingFailure(failure, () -> {
                    try (CuVSResources.ScopedAccess resourcesAccessor = this.resources.access()) {
                        Util.checkCuVSError(headers_h.cuvsRMMFree(resourcesAccessor.handle(), ref.datasetPtr, ref.datasetBytes), "cuvsRMMFree");
                    }
                });
            }
            if (ref.tensorDataArena != null) {
                failure = runCollectingFailure(failure, ref.tensorDataArena::close);
            }
            if (failure != null) {
                throw failure;
            }
        } finally {
            this.destroyed = true;
        }
    }

    /**
     * Creates a native index and builds it over {@code dataset}.
     *
     * <p>The OpenMP thread count is set from {@code bruteForceIndexParams.getNumWriterThreads()} for the
     * duration of the build and reset to 1 afterwards, also when the build fails.
     */
    private IndexReference build(CuVSMatrixBaseImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
        long rows = dataset.size();
        long cols = dataset.columns();
        MemorySegment datasetMemSegment = dataset.memorySegment();
        // NOTE(review): assumes a float32 dataset (C_FLOAT_BYTE_SIZE, dtype 2/32 below) — confirm for other matrix types.
        long datasetBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * rows * cols;
        headers_h.omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
        try {
            // Created inside the try so the thread count is reset even if handle creation fails.
            MemorySegment index = createBruteForceIndex();
            try (CuVSResources.ScopedAccess resourcesAccessor = this.resources.access()) {
                return uploadAndBuild(resourcesAccessor.handle(), datasetMemSegment, rows, cols, datasetBytes, index);
            } catch (Throwable t) {
                destroyAfterFailure(index, t);
                throw t;
            }
        } finally {
            headers_h.omp_set_num_threads(1);
        }
    }

    /**
     * Copies the host dataset into a new RMM allocation and runs {@code cuvsBruteForceBuild} into
     * {@code index}. On failure the device copy and the tensor arena are released before rethrowing.
     */
    private static IndexReference uploadAndBuild(long cuvsResources, MemorySegment hostDataset, long rows, long cols, long datasetBytes, MemorySegment index) {
        MemorySegment deviceDataset = Util.allocateRMMSegment(cuvsResources, datasetBytes);
        // Kept open for the lifetime of the index and closed in destroyIndex().
        Arena tensorDataArena = Arena.ofShared();
        try {
            Util.cudaMemcpy(deviceDataset, hostDataset, datasetBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
            MemorySegment datasetTensor = Util.prepareTensor(tensorDataArena, deviceDataset, new long[]{rows, cols}, 2, 32, 2, 1);
            Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsResources), "cuvsStreamSync");
            // NOTE(review): metric code 0 and metric arg 0.0f are hard-coded — confirm against cuvsDistanceType.
            Util.checkCuVSError(headers_h.cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, index), "cuvsBruteForceBuild");
            Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsResources), "cuvsStreamSync");
            return new IndexReference(deviceDataset, datasetBytes, tensorDataArena, index);
        } catch (Throwable t) {
            cleanupAfterFailure(t, tensorDataArena::close);
            cleanupAfterFailure(t, () -> Util.checkCuVSError(headers_h.cuvsRMMFree(cuvsResources, deviceDataset, datasetBytes), "cuvsRMMFree"));
            throw t;
        }
    }

    /**
     * Runs a top-k search for every query vector, optionally restricted by per-query prefilters.
     *
     * <p>Queries (and prefilters) are staged in RMM device memory; all device allocations are freed
     * before returning, including when the search fails.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
        this.checkNotDestroyed();
        try (Arena localArena = Arena.ofConfined()) {
            var queryVectors = cuvsQuery.getQueryVectors();
            long numQueries = queryVectors.length;
            int topk = cuvsQuery.getTopK();
            long numBlocks = (long) topk * numQueries;
            int vectorDimension = numQueries > 0L ? queryVectors[0].length : 0;

            // Host buffers that receive the results copied back from the device.
            SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
            SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
            MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
            MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);

            MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
            long prefilterWords = 0L;
            long prefilterBytes = 0L;
            BitSet[] prefilters = cuvsQuery.getPrefilters();
            if (prefilters != null && prefilters.length > 0) {
                long prefilterBits = (long) cuvsQuery.getNumDocs() * (long) prefilters.length;
                prefilterWords = (prefilterBits + 31L) / 32L;
                prefilterBytes = LinkerHelper.C_INT_BYTE_SIZE * prefilterWords;
                BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
                long[] filters = concatenatedFilters.toLongArray();
                // BitSet.toLongArray() omits trailing all-zero words, so the array can be shorter than the
                // prefilterBytes copied to the device below; pad it so the copy never reads past the host buffer.
                int requiredLongs = Math.toIntExact((prefilterBytes + Long.BYTES - 1L) / Long.BYTES);
                if (filters.length < requiredLongs) {
                    filters = Arrays.copyOf(filters, requiredLongs);
                }
                prefilterDataMemorySegment = Util.buildMemorySegment(localArena, filters);
            }
            MemorySegment querySeg = Util.buildMemorySegment(localArena, queryVectors);

            // Resources close in reverse order: device memory is freed while the handle is still held.
            try (CuVSResources.ScopedAccess resourcesAccessor = this.resources.access();
                 DeviceAllocations device = new DeviceAllocations(resourcesAccessor.handle())) {
                long cuvsResources = resourcesAccessor.handle();
                long queriesBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long) vectorDimension;
                long neighborsBytes = LinkerHelper.C_LONG_BYTE_SIZE * numQueries * (long) topk;
                long distanceBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long) topk;
                MemorySegment queriesDP = device.allocate(queriesBytes);
                MemorySegment neighborsDP = device.allocate(neighborsBytes);
                MemorySegment distancesDP = device.allocate(distanceBytes);
                Util.cudaMemcpy(queriesDP, querySeg, queriesBytes, Util.CudaMemcpyKind.INFER_DIRECTION);

                // NOTE(review): (code, bits) presumably follow DLPack — 2/32 float32, 0/64 int64, 1/32 uint32; confirm in Util.prepareTensor.
                MemorySegment queriesTensor = Util.prepareTensor(localArena, queriesDP, new long[]{numQueries, vectorDimension}, 2, 32, 2, 1);
                MemorySegment neighborsTensor = Util.prepareTensor(localArena, neighborsDP, new long[]{numQueries, topk}, 0, 64, 2, 1);
                MemorySegment distancesTensor = Util.prepareTensor(localArena, distancesDP, new long[]{numQueries, topk}, 2, 32, 2, 1);

                MemorySegment prefilter = cuvsFilter.allocate(localArena);
                if (prefilterDataMemorySegment == MemorySegment.NULL) {
                    cuvsFilter.type(prefilter, 0);
                    cuvsFilter.addr(prefilter, 0L);
                } else {
                    MemorySegment prefilterDP = device.allocate(prefilterBytes);
                    Util.cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, Util.CudaMemcpyKind.HOST_TO_DEVICE);
                    MemorySegment prefilterTensor = Util.prepareTensor(localArena, prefilterDP, new long[]{prefilterWords}, 1, 32, 2, 1);
                    // NOTE(review): type 2 is presumably the bitmap filter in cuvsFilterType — confirm.
                    cuvsFilter.type(prefilter, 2);
                    cuvsFilter.addr(prefilter, prefilterTensor.address());
                }

                Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsResources), "cuvsStreamSync");
                Util.checkCuVSError(headers_h.cuvsBruteForceSearch(cuvsResources, this.bruteForceIndexReference.indexPtr, queriesTensor, neighborsTensor, distancesTensor, prefilter), "cuvsBruteForceSearch");
                Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsResources), "cuvsStreamSync");
                Util.cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
                Util.cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
            }
            return BruteForceSearchResults.create(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, distancesMemorySegment, topk, cuvsQuery.getMapping(), numQueries);
        }
    }

    /**
     * Serializes the index to {@code outputStream} via a fresh temp file in the resources' temp directory.
     *
     * @throws IllegalStateException if the index was destroyed (checked before any file is created)
     */
    @Override
    public void serialize(OutputStream outputStream) throws Throwable {
        this.checkNotDestroyed();
        Path path = Files.createTempFile(this.resources.tempDirectory(), UUID.randomUUID().toString(), ".bf");
        this.serialize(outputStream, path);
    }

    /**
     * Serializes the index into {@code tempFile} with the native API, then streams that file to
     * {@code outputStream}. {@code tempFile} is deleted afterwards, whether or not serialization succeeded.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public void serialize(OutputStream outputStream, Path tempFile) throws Throwable {
        this.checkNotDestroyed();
        Path tempFilePath = tempFile.toAbsolutePath();
        try {
            try (Arena localArena = Arena.ofConfined();
                 CuVSResources.ScopedAccess resourcesAccessor = this.resources.access()) {
                int returnValue = headers_h.cuvsBruteForceSerialize(resourcesAccessor.handle(), localArena.allocateFrom(tempFilePath.toString()), this.bruteForceIndexReference.indexPtr);
                Util.checkCuVSError(returnValue, "cuvsBruteForceSerialize");
            }
            try (InputStream inputStream = Files.newInputStream(tempFilePath)) {
                inputStream.transferTo(outputStream);
            }
        } finally {
            Files.deleteIfExists(tempFilePath);
        }
    }

    /** Allocates an empty native brute-force index handle via {@code cuvsBruteForceIndexCreate}. */
    private static MemorySegment createBruteForceIndex() {
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment indexPtrPtr = localArena.allocate(headers_h.cuvsBruteForceIndex_t);
            Util.checkCuVSError(headers_h.cuvsBruteForceIndexCreate(indexPtrPtr), "cuvsBruteForceIndexCreate");
            return indexPtrPtr.get(headers_h.cuvsBruteForceIndex_t, 0L);
        }
    }

    /**
     * Reads a serialized index from {@code inputStream} (which is closed) into a new native handle.
     *
     * <p>The stream is spooled to a temp file, closed, and then read by the native API. The temp file is
     * always deleted, and the native handle is destroyed if deserialization fails.
     */
    private IndexReference deserialize(InputStream inputStream) throws Throwable {
        this.checkNotDestroyed();
        Path tmpIndexFile = Files.createTempFile(this.resources.tempDirectory(), UUID.randomUUID().toString(), ".bf").toAbsolutePath();
        try {
            // Close the file before the native reader opens it.
            try (InputStream in = inputStream;
                 OutputStream out = Files.newOutputStream(tmpIndexFile)) {
                in.transferTo(out);
            }
            MemorySegment indexPtr = createBruteForceIndex();
            try (Arena arena = Arena.ofConfined();
                 CuVSResources.ScopedAccess resourcesAccessor = this.resources.access()) {
                int returnValue = headers_h.cuvsBruteForceDeserialize(resourcesAccessor.handle(), arena.allocateFrom(tmpIndexFile.toString()), indexPtr);
                Util.checkCuVSError(returnValue, "cuvsBruteForceDeserialize");
            } catch (Throwable t) {
                destroyAfterFailure(indexPtr, t);
                throw t;
            }
            return new IndexReference(indexPtr);
        } finally {
            Files.deleteIfExists(tmpIndexFile);
        }
    }

    /**
     * Runs one cleanup step and folds any {@link RuntimeException} it throws into {@code failure}
     * (as a suppressed exception when a failure is already recorded), so later steps still run.
     *
     * @return the first failure seen so far, or {@code null} if none
     */
    private static RuntimeException runCollectingFailure(RuntimeException failure, Runnable step) {
        try {
            step.run();
            return failure;
        } catch (RuntimeException e) {
            if (failure == null) {
                return e;
            }
            failure.addSuppressed(e);
            return failure;
        }
    }

    /** Runs a best-effort cleanup after {@code primary} was thrown, attaching any cleanup error to it. */
    private static void cleanupAfterFailure(Throwable primary, Runnable cleanup) {
        try {
            cleanup.run();
        } catch (Throwable t) {
            primary.addSuppressed(t);
        }
    }

    /** Best-effort destruction of a native index handle after {@code primary} aborted its setup. */
    private static void destroyAfterFailure(MemorySegment indexPtr, Throwable primary) {
        cleanupAfterFailure(primary,
                () -> Util.checkCuVSError(headers_h.cuvsBruteForceIndexDestroy(indexPtr), "cuvsBruteForceIndexDestroy"));
    }

    public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
        return new Builder(Objects.requireNonNull(cuvsResources));
    }

    /**
     * Tracks RMM device allocations made with one resources handle and frees them, newest first, on
     * {@link #close()}. Every allocation is freed even if one free fails; the first failure is rethrown.
     */
    private static final class DeviceAllocations implements AutoCloseable {
        private record Allocation(MemorySegment segment, long bytes) {}

        private final long cuvsResources;
        private final Deque<Allocation> allocations = new ArrayDeque<>();

        DeviceAllocations(long cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        /** Allocates {@code bytes} of RMM memory and registers it for release on close. */
        MemorySegment allocate(long bytes) {
            MemorySegment segment = Util.allocateRMMSegment(this.cuvsResources, bytes);
            this.allocations.push(new Allocation(segment, bytes));
            return segment;
        }

        @Override
        public void close() {
            RuntimeException failure = null;
            while (!this.allocations.isEmpty()) {
                Allocation allocation = this.allocations.pop();
                failure = runCollectingFailure(failure,
                        () -> Util.checkCuVSError(headers_h.cuvsRMMFree(this.cuvsResources, allocation.segment(), allocation.bytes()), "cuvsRMMFree"));
            }
            if (failure != null) {
                throw failure;
            }
        }
    }

    /**
     * Native state owned by an index: the index handle and, for built indexes, the RMM dataset copy
     * ({@code datasetPtr}/{@code datasetBytes}) plus the arena holding its tensor descriptor.
     * Deserialized indexes carry only the handle ({@code datasetBytes == 0}, no arena).
     */
    private static final class IndexReference {
        private final MemorySegment datasetPtr;
        private final long datasetBytes;
        private final Arena tensorDataArena;
        private final MemorySegment indexPtr;

        private IndexReference(MemorySegment datasetPtr, long datasetBytes, Arena tensorDataArena, MemorySegment indexPtr) {
            this.datasetPtr = datasetPtr;
            this.datasetBytes = datasetBytes;
            this.tensorDataArena = tensorDataArena;
            this.indexPtr = indexPtr;
        }

        private IndexReference(MemorySegment indexPtr) {
            this(MemorySegment.NULL, 0L, null, indexPtr);
        }
    }

    /**
     * Builder that either loads a serialized index ({@link #from(InputStream)}, which takes precedence)
     * or builds one from a dataset ({@code withDataset}).
     */
    public static class Builder implements BruteForceIndex.Builder {
        private CuVSMatrix dataset;
        private final CuVSResources cuvsResources;
        private BruteForceIndexParams bruteForceIndexParams;
        private InputStream inputStream;

        public Builder(CuVSResources cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        @Override
        public Builder withIndexParams(BruteForceIndexParams bruteForceIndexParams) {
            this.bruteForceIndexParams = bruteForceIndexParams;
            return this;
        }

        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        @Override
        public Builder withDataset(float[][] vectors) {
            this.dataset = CuVSMatrix.ofArray(vectors);
            return this;
        }

        @Override
        public Builder withDataset(CuVSMatrix dataset) {
            this.dataset = dataset;
            return this;
        }

        @Override
        public BruteForceIndexImpl build() throws Throwable {
            if (this.inputStream != null) {
                return new BruteForceIndexImpl(this.inputStream, this.cuvsResources);
            }
            return new BruteForceIndexImpl(this.dataset, this.cuvsResources, this.bruteForceIndexParams);
        }
    }
}

