/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.BruteForceIndexParams;
import com.nvidia.cuvs.BruteForceQuery;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.Dataset;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.BruteForceSearchResults;
import com.nvidia.cuvs.internal.CuVSResourcesImpl;
import com.nvidia.cuvs.internal.DatasetImpl;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsBruteForceIndex;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Objects;
import java.util.UUID;

/**
 * {@link BruteForceIndex} implementation that delegates to the native cuVS brute-force
 * functions through Panama downcall handles.
 *
 * <p>An instance is either built from host vectors or a {@link Dataset}, or deserialized
 * from a stream. Once {@link #destroyIndex()} has been called, every further operation
 * fails with {@link IllegalStateException}.
 */
public class BruteForceIndexImpl implements BruteForceIndex {
    private static final MethodHandle indexMethodHandle = LinkerHelper.downcallHandle("build_brute_force_index", FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_LONG, LinkerHelper.C_LONG, ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_INT), new Linker.Option[0]);
    private static final MethodHandle searchMethodHandle = LinkerHelper.downcallHandle("search_brute_force_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_INT, LinkerHelper.C_LONG, LinkerHelper.C_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, LinkerHelper.C_LONG), new Linker.Option[0]);
    private static final MethodHandle destroyIndexMethodHandle = LinkerHelper.downcallHandle("destroy_brute_force_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS), new Linker.Option[0]);
    private static final MethodHandle serializeMethodHandle = LinkerHelper.downcallHandle("serialize_brute_force_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS), new Linker.Option[0]);
    private static final MethodHandle deserializeMethodHandle = LinkerHelper.downcallHandle("deserialize_brute_force_index", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS), new Linker.Option[0]);

    /** Host vectors from the builder; only used to build when {@link #dataset} is null. Null for deserialized indexes. */
    private final float[][] vectors;
    /** Optional dataset; takes precedence over {@link #vectors} and is closed by {@link #destroyIndex()}. */
    private final Dataset dataset;
    private final CuVSResourcesImpl resources;
    /** Handle to the native index, produced by {@link #build()} or {@link #deserialize(InputStream)}. */
    private final IndexReference bruteForceIndexReference;
    /** Build parameters; null for deserialized indexes. */
    private final BruteForceIndexParams bruteForceIndexParams;
    /** Set once {@link #destroyIndex()} has been attempted, whether or not it succeeded. */
    private boolean destroyed;

    /** Creates the index and immediately builds it natively from {@code dataset} or {@code vectors}. */
    private BruteForceIndexImpl(float[][] vectors, Dataset dataset, CuVSResourcesImpl resources, BruteForceIndexParams bruteForceIndexParams) throws Throwable {
        this.vectors = vectors;
        this.dataset = dataset;
        this.resources = resources;
        this.bruteForceIndexParams = bruteForceIndexParams;
        this.bruteForceIndexReference = this.build();
    }

    /** Creates the index by deserializing it from {@code inputStream}; the stream is closed afterwards. */
    private BruteForceIndexImpl(InputStream inputStream, CuVSResourcesImpl resources) throws Throwable {
        this.bruteForceIndexParams = null;
        this.vectors = null;
        this.dataset = null;
        this.resources = resources;
        this.bruteForceIndexReference = this.deserialize(inputStream);
    }

    private void checkNotDestroyed() {
        if (destroyed) {
            throw new IllegalStateException("destroyed");
        }
    }

    /**
     * Releases the native index and closes the backing {@link Dataset}, if any.
     *
     * <p>The instance is marked destroyed and the dataset is closed even when the native
     * call reports an error. A failed destroy therefore cannot be retried and does not
     * leak the dataset.
     *
     * @throws IllegalStateException if the index was already destroyed
     * @throws Throwable if {@code Util.checkError} rejects the native return code
     */
    @Override
    public void destroyIndex() throws Throwable {
        checkNotDestroyed();
        try {
            try (Arena localArena = Arena.ofConfined()) {
                MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
                destroyIndexMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), returnValue);
                Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "destroyIndexMethodHandle");
            }
        } finally {
            destroyed = true;
            if (dataset != null) {
                dataset.close();
            }
        }
    }

    /**
     * Builds the native index from {@link #dataset} if present, otherwise from {@link #vectors}.
     *
     * @return reference wrapping the segment returned by {@code build_brute_force_index}
     * @throws NullPointerException if no index params were supplied, or if neither a dataset
     *     nor vectors were supplied
     */
    private IndexReference build() throws Throwable {
        Objects.requireNonNull(bruteForceIndexParams, "bruteForceIndexParams must be set to build an index");
        long rows;
        long cols;
        MemorySegment dataSeg;
        if (dataset != null) {
            rows = dataset.size();
            cols = dataset.dimensions();
            dataSeg = ((DatasetImpl) dataset).seg;
        } else {
            Objects.requireNonNull(vectors, "either a Dataset or float[][] vectors must be supplied");
            rows = vectors.length;
            cols = rows > 0L ? vectors[0].length : 0L;
            // NOTE(review): allocated in resources.getArena(), not a local arena. It is presumably
            // freed only when the resources are closed; confirm whether the native index keeps it.
            dataSeg = Util.buildMemorySegment(resources.getArena(), vectors);
        }
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
            // invokeExact is signature-polymorphic: the cast fixes the expected return type.
            MemorySegment indexSeg = (MemorySegment) indexMethodHandle.invokeExact(
                    dataSeg, rows, cols, resources.getMemorySegment(), returnValue,
                    bruteForceIndexParams.getNumWriterThreads());
            Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "indexMethodHandle");
            return new IndexReference(indexSeg);
        }
    }

    /**
     * Runs a top-k brute-force search for every query vector in {@code cuvsQuery}.
     *
     * <p>Query vectors and prefilter bitsets are only needed for the duration of the native
     * call. They are therefore staged in a local confined arena rather than the shared
     * resources arena.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
        checkNotDestroyed();
        // 'var' keeps the exact static types, which invokeExact requires to match its descriptor.
        var queryVectors = cuvsQuery.getQueryVectors();
        var topK = cuvsQuery.getTopK();
        long numQueries = queryVectors.length;
        long numBlocks = (long) topK * numQueries;
        int vectorDimension = numQueries > 0L ? queryVectors[0].length : 0;

        // NOTE(review): result buffers stay in resources.getArena() because they are handed to
        // BruteForceSearchResults, which may read them after this method returns.
        SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
        SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
        MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout);
        MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout);

        try (Arena localArena = Arena.ofConfined()) {
            long prefilterDataLength = 0L;
            MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
            BitSet[] prefilters = cuvsQuery.getPrefilters();
            if (prefilters != null && prefilters.length > 0) {
                BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
                // Widen before multiplying so large numDocs * numFilters cannot overflow int.
                prefilterDataLength = (long) cuvsQuery.getNumDocs() * prefilters.length;
                // BitSet.toLongArray() drops trailing all-zero words, e.g. when the last filters are
                // all clear. Pad so the buffer covers prefilterDataLength bits. This assumes the
                // native side treats the length as a bit count; NOTE(review): confirm.
                long[] filters = concatenatedFilters.toLongArray();
                int requiredWords = Math.toIntExact((prefilterDataLength + Long.SIZE - 1) / Long.SIZE);
                if (filters.length < requiredWords) {
                    filters = Arrays.copyOf(filters, requiredWords);
                }
                prefilterDataMemorySegment = Util.buildMemorySegment(localArena, filters);
            }
            MemorySegment querySeg = Util.buildMemorySegment(localArena, queryVectors);
            MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
            searchMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), querySeg, topK, numQueries, vectorDimension,
                    resources.getMemorySegment(), neighborsMemorySegment, distancesMemorySegment, returnValue,
                    prefilterDataMemorySegment, prefilterDataLength);
            Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "searchMethodHandle");
        }
        return new BruteForceSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment,
                distancesMemorySegment, topK, cuvsQuery.getMapping(), numQueries);
    }

    /**
     * Serializes the index to {@code outputStream}. The data is staged through a temporary
     * {@code .bf} file created in {@code resources.tempDirectory()}.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public void serialize(OutputStream outputStream) throws Throwable {
        // Check before creating the temp file so a destroyed index leaves nothing behind.
        checkNotDestroyed();
        Path tempFile = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".bf");
        serialize(outputStream, tempFile);
    }

    /**
     * Serializes the index into {@code tempFile} natively, then copies the file's bytes to
     * {@code outputStream}. {@code tempFile} is deleted afterwards, including when the native
     * call or the copy fails.
     *
     * @throws IllegalStateException if the index was destroyed
     */
    @Override
    public void serialize(OutputStream outputStream, Path tempFile) throws Throwable {
        checkNotDestroyed();
        Path absoluteTempFile = tempFile.toAbsolutePath();
        try {
            try (Arena localArena = Arena.ofConfined()) {
                MemorySegment pathSeg = Util.buildMemorySegment(localArena, absoluteTempFile.toString());
                MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
                serializeMethodHandle.invokeExact(resources.getMemorySegment(), bruteForceIndexReference.getMemorySegment(), returnValue, pathSeg);
                Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "serializeMethodHandle");
            }
            try (InputStream in = Files.newInputStream(absoluteTempFile)) {
                in.transferTo(outputStream);
            }
        } finally {
            Files.deleteIfExists(absoluteTempFile);
        }
    }

    /**
     * Copies {@code inputStream} into a temporary {@code .bf} file, then loads the native index
     * from it. The input stream is closed and the temp file deleted in all cases.
     *
     * @return reference to a freshly allocated native index struct filled by the native call
     */
    private IndexReference deserialize(InputStream inputStream) throws Throwable {
        IndexReference indexReference = new IndexReference(resources);
        Path tmpIndexFile = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".bf").toAbsolutePath();
        try {
            // Close the file before native code reads it so all bytes are flushed and released.
            try (InputStream in = inputStream; OutputStream out = Files.newOutputStream(tmpIndexFile)) {
                in.transferTo(out);
            }
            try (Arena localArena = Arena.ofConfined()) {
                MemorySegment pathSeg = Util.buildMemorySegment(localArena, tmpIndexFile.toString());
                MemorySegment returnValue = localArena.allocate(LinkerHelper.C_INT);
                deserializeMethodHandle.invokeExact(resources.getMemorySegment(), indexReference.getMemorySegment(), returnValue, pathSeg);
                Util.checkError(returnValue.get(LinkerHelper.C_INT, 0L), "deserializeMethodHandle");
            }
        } finally {
            Files.deleteIfExists(tmpIndexFile);
        }
        return indexReference;
    }

    /**
     * Returns a builder bound to {@code cuvsResources}.
     *
     * @throws NullPointerException if {@code cuvsResources} is null
     * @throws IllegalArgumentException if {@code cuvsResources} is not a {@link CuVSResourcesImpl}
     */
    public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
        Objects.requireNonNull(cuvsResources);
        if (cuvsResources instanceof CuVSResourcesImpl resourcesImpl) {
            return new Builder(resourcesImpl);
        }
        throw new IllegalArgumentException("Unsupported " + cuvsResources);
    }

    /** Thin wrapper around the memory segment that identifies the native index. */
    protected static class IndexReference {
        private final MemorySegment memorySegment;

        /** Allocates an empty {@code cuvsBruteForceIndex} struct in the resources arena, to be filled natively. */
        protected IndexReference(CuVSResourcesImpl resources) {
            this.memorySegment = cuvsBruteForceIndex.allocate(resources.getArena());
        }

        /** Wraps a segment already returned by native code. */
        protected IndexReference(MemorySegment indexMemorySegment) {
            this.memorySegment = indexMemorySegment;
        }

        protected MemorySegment getMemorySegment() {
            return this.memorySegment;
        }
    }

    /**
     * Builder for {@link BruteForceIndexImpl}. If an input stream is set via
     * {@link #from(InputStream)}, {@link #build()} deserializes from it and ignores any
     * vectors, dataset or index params. Otherwise it builds from the dataset, or from the
     * vectors when no dataset is set.
     */
    public static class Builder implements BruteForceIndex.Builder {
        private float[][] vectors;
        private Dataset dataset;
        private final CuVSResourcesImpl cuvsResources;
        private BruteForceIndexParams bruteForceIndexParams;
        private InputStream inputStream;

        public Builder(CuVSResourcesImpl cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        @Override
        public Builder withIndexParams(BruteForceIndexParams bruteForceIndexParams) {
            this.bruteForceIndexParams = bruteForceIndexParams;
            return this;
        }

        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        @Override
        public Builder withDataset(float[][] vectors) {
            this.vectors = vectors;
            return this;
        }

        @Override
        public Builder withDataset(Dataset dataset) {
            this.dataset = dataset;
            return this;
        }

        @Override
        public BruteForceIndexImpl build() throws Throwable {
            if (this.inputStream != null) {
                return new BruteForceIndexImpl(this.inputStream, this.cuvsResources);
            }
            return new BruteForceIndexImpl(this.vectors, this.dataset, this.cuvsResources, this.bruteForceIndexParams);
        }
    }
}

