/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal.common;

import com.nvidia.cuvs.GPUInfo;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.panama.DLDataType;
import com.nvidia.cuvs.internal.panama.DLDevice;
import com.nvidia.cuvs.internal.panama.DLManagedTensor;
import com.nvidia.cuvs.internal.panama.DLTensor;
import com.nvidia.cuvs.internal.panama.cudaDeviceProp;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.invoke.VarHandle;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

public class Util {
    /** Status code that cuVS C API calls return on success. */
    public static final int CUVS_SUCCESS = 1;
    /** Status code that CUDA runtime calls return on success. */
    public static final int CUDA_SUCCESS = 0;
    /** Byte bound applied to the native last-error pointer before reading it as a string. */
    static final long MAX_ERROR_TEXT = 1000000L;

    private Util() {
    }

    /**
     * Throws if a cuVS C API call did not report success.
     *
     * @param value  the status code returned by the native call
     * @param caller name of the native function, used in the exception message
     * @throws RuntimeException if {@code value != CUVS_SUCCESS}; the message includes the
     *                          text from {@code cuvsGetLastErrorText}
     */
    public static void checkCuVSError(int value, String caller) {
        if (value != CUVS_SUCCESS) {
            String errorMsg = getLastErrorText();
            throw new RuntimeException(caller + " returned " + value + "[" + errorMsg + "]");
        }
    }

    /**
     * Throws if a CUDA runtime call did not report success.
     *
     * @param value  the status code returned by the native call
     * @param caller name of the native function, used in the exception message
     * @throws RuntimeException if {@code value != CUDA_SUCCESS}
     */
    public static void checkCudaError(int value, String caller) {
        if (value != CUDA_SUCCESS) {
            throw new RuntimeException(caller + " returned " + value);
        }
    }

    /**
     * Copies {@code numBytes} from {@code src} to {@code dest} via {@code cudaMemcpy}.
     *
     * @throws RuntimeException if the CUDA call fails
     */
    public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes, CudaMemcpyKind kind) {
        int returnValue = headers_h.cudaMemcpy(dest, src, numBytes, kind.kind);
        checkCudaError(returnValue, "cudaMemcpy");
    }

    /**
     * Same as {@link #cudaMemcpy(MemorySegment, MemorySegment, long, CudaMemcpyKind)} with
     * {@link CudaMemcpyKind#INFER_DIRECTION}.
     */
    public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes) {
        cudaMemcpy(dest, src, numBytes, CudaMemcpyKind.INFER_DIRECTION);
    }

    /**
     * Returns the text reported by {@code cuvsGetLastErrorText}, or
     * {@code "no last error text"} when it returns a NULL pointer.
     *
     * @throws RuntimeException wrapping anything thrown by the native invocation
     */
    static String getLastErrorText() {
        try {
            MemorySegment seg = headers_h.cuvsGetLastErrorText.makeInvoker().apply();
            if (seg.equals(MemorySegment.NULL)) {
                return "no last error text";
            }
            // Pointers returned from native code carry no size, so widen the segment
            // before scanning for the NUL terminator; MAX_ERROR_TEXT bounds that scan.
            return seg.reinterpret(MAX_ERROR_TEXT).getString(0L);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Returns the GPUs with compute capability of at least 7.0 and at least 8192 MiB of total
     * memory.
     */
    public static List<GPUInfo> compatibleGPUs() throws Throwable {
        return compatibleGPUs(7.0, 8192);
    }

    /**
     * Filters {@link #availableGPUs()} by minimum compute capability and total device memory.
     *
     * @param minComputeCapability smallest acceptable {@code major.minor} compute capability
     * @param minDeviceMemoryMB    smallest acceptable total memory, in MiB (2^20 bytes)
     * @return the matching GPUs, in enumeration order
     */
    public static List<GPUInfo> compatibleGPUs(double minComputeCapability, int minDeviceMemoryMB) throws Throwable {
        List<GPUInfo> compatibleGPUs = new ArrayList<>();
        double minDeviceMemoryB = Math.pow(2.0, 20.0) * (double) minDeviceMemoryMB;
        for (GPUInfo gpuInfo : availableGPUs()) {
            if (gpuInfo.computeCapability() >= minComputeCapability
                    && gpuInfo.totalMemory() >= minDeviceMemoryB) {
                compatibleGPUs.add(gpuInfo);
            }
        }
        return compatibleGPUs;
    }

    /**
     * Enumerates all CUDA devices and queries each one's properties and memory.
     *
     * <p>Side effect: this calls {@code cudaSetDevice(i)} for every device and does not
     * restore the previously selected device afterwards.
     *
     * @return one {@link GPUInfo} per device, ordered by device id
     * @throws RuntimeException if any CUDA call fails
     */
    public static List<GPUInfo> availableGPUs() throws Throwable {
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment numGpus = localArena.allocate(LinkerHelper.C_INT);
            checkCudaError(headers_h.cudaGetDeviceCount(numGpus), "cudaGetDeviceCount");
            int numGpuCount = numGpus.get(LinkerHelper.C_INT, 0L);

            List<GPUInfo> gpuInfos = new ArrayList<>();
            // Scratch buffers reused for every device.
            MemorySegment free = localArena.allocate(headers_h.size_t);
            MemorySegment total = localArena.allocate(headers_h.size_t);
            MemorySegment deviceProp = cudaDeviceProp.allocate(localArena);
            for (int i = 0; i < numGpuCount; ++i) {
                // cudaMemGetInfo reports on the current device, so select device i first.
                checkCudaError(headers_h.cudaSetDevice(i), "cudaSetDevice");
                checkCudaError(headers_h.cudaGetDeviceProperties_v2(deviceProp, i), "cudaGetDeviceProperties_v2");
                checkCudaError(headers_h.cudaMemGetInfo(free, total), "cudaMemGetInfo");
                float computeCapability = Float.parseFloat(cudaDeviceProp.major(deviceProp) + "." + cudaDeviceProp.minor(deviceProp));
                gpuInfos.add(new GPUInfo(
                        i,
                        cudaDeviceProp.name(deviceProp).getString(0L),
                        free.get(LinkerHelper.C_LONG, 0L),
                        total.get(LinkerHelper.C_LONG, 0L),
                        computeCapability));
            }
            return gpuInfos;
        }
    }

    /**
     * Allocates a NUL-terminated C string holding {@code str} encoded as UTF-8.
     *
     * <p>The previous implementation truncated each UTF-16 {@code char} to one byte, which
     * corrupted non-ASCII text. It also built a fresh {@code VarHandle} per character. For
     * pure-ASCII input the resulting bytes are identical.
     *
     * @param arena arena that owns the allocation
     * @param str   the string to encode
     * @return a segment of {@code utf8Length + 1} bytes, the last one being {@code 0}
     */
    public static MemorySegment buildMemorySegment(Arena arena, String str) {
        byte[] utf8 = str.getBytes(StandardCharsets.UTF_8);
        // copyOf zero-fills the extra trailing slot, which supplies the NUL terminator.
        byte[] nulTerminated = Arrays.copyOf(utf8, utf8.length + 1);
        return buildMemorySegment(arena, nulTerminated);
    }

    /** Allocates a segment in {@code arena} and copies {@code data} into it as C longs. */
    public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
        int cells = data.length;
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, LinkerHelper.C_LONG);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        MemorySegment.copy(data, 0, dataMemorySegment, LinkerHelper.C_LONG, 0L, cells);
        return dataMemorySegment;
    }

    /** Allocates a segment in {@code arena} and copies {@code data} into it as C chars. */
    public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
        int cells = data.length;
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, LinkerHelper.C_CHAR);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        MemorySegment.copy(data, 0, dataMemorySegment, LinkerHelper.C_CHAR, 0L, cells);
        return dataMemorySegment;
    }

    /**
     * Allocates a {@code rows * cols} float segment and fills it row-major from {@code data}.
     * The column count is taken from {@code data[0]}.
     */
    public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
        long rows = data.length;
        long cols = rows > 0L ? (long) data[0].length : 0L;
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout(rows * cols, LinkerHelper.C_FLOAT);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        copy(dataMemorySegment, data);
        return dataMemorySegment;
    }

    /**
     * Copies {@code data} row-major into {@code memorySegment}. Every row is assumed to have
     * {@code data[0].length} columns.
     */
    public static void copy(MemorySegment memorySegment, float[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Row stride in bytes; computed as long so r * rowBytes cannot overflow int.
        long rowBytes = cols * LinkerHelper.C_FLOAT.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_FLOAT, r * rowBytes, cols);
        }
    }

    /**
     * Copies {@code data} row-major into {@code memorySegment}. Every row is assumed to have
     * {@code data[0].length} columns.
     */
    public static void copy(MemorySegment memorySegment, int[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Row stride in bytes; computed as long so r * rowBytes cannot overflow int.
        long rowBytes = cols * LinkerHelper.C_INT.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_INT, r * rowBytes, cols);
        }
    }

    /**
     * Copies {@code data} row-major into {@code memorySegment}. Every row is assumed to have
     * {@code data[0].length} columns.
     */
    public static void copy(MemorySegment memorySegment, byte[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Row stride in bytes; computed as long so r * rowBytes cannot overflow int.
        long rowBytes = cols * LinkerHelper.C_CHAR.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_CHAR, r * rowBytes, cols);
        }
    }

    /**
     * Packs {@code arr} into a single bit set in which entry {@code i} occupies bits
     * {@code [i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet)}.
     *
     * <p>A {@code null} entry, or one with no set bits, produces a fully set range. Otherwise
     * only bits {@code 0 .. maxSizeOfEachBitSet - 1} of the entry are copied; higher bits are
     * ignored.
     */
    public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
        BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
        for (int i = 0; i < arr.length; ++i) {
            int offset = i * maxSizeOfEachBitSet;
            BitSet b = arr[i];
            if (b == null || b.isEmpty()) {
                ret.set(offset, offset + maxSizeOfEachBitSet);
                continue;
            }
            // ret starts cleared, so copying only the set bits is equivalent to copying all bits.
            for (int j = b.nextSetBit(0); j >= 0 && j < maxSizeOfEachBitSet; j = b.nextSetBit(j + 1)) {
                ret.set(offset + j);
            }
        }
        return ret;
    }

    /**
     * Allocates a {@code DLManagedTensor} in {@code arena} that describes {@code data}.
     *
     * <p>Strides are set to NULL. The shape array is copied into {@code arena}.
     *
     * @param data       pointer to the tensor contents; not copied
     * @param shape      extent of each dimension; its length becomes {@code ndim}
     * @param code       DLPack data-type code (stored as a byte)
     * @param bits       bits per element lane (stored as a byte)
     * @param deviceType DLPack device type
     * @param lanes      vector lanes per element (stored as a short)
     */
    public static MemorySegment prepareTensor(Arena arena, MemorySegment data, long[] shape, int code, int bits, int deviceType, int lanes) {
        MemorySegment tensor = DLManagedTensor.allocate(arena);
        MemorySegment dlTensor = DLTensor.allocate(arena);
        DLTensor.data(dlTensor, data);

        MemorySegment dlDevice = DLDevice.allocate(arena);
        DLDevice.device_type(dlDevice, deviceType);
        DLTensor.device(dlTensor, dlDevice);

        DLTensor.ndim(dlTensor, shape.length);

        MemorySegment dtype = DLDataType.allocate(arena);
        DLDataType.code(dtype, (byte) code);
        DLDataType.bits(dtype, (byte) bits);
        DLDataType.lanes(dtype, (short) lanes);
        DLTensor.dtype(dlTensor, dtype);

        DLTensor.shape(dlTensor, buildMemorySegment(arena, shape));
        DLTensor.strides(dlTensor, MemorySegment.NULL);
        DLManagedTensor.dl_tensor(tensor, dlTensor);
        return tensor;
    }

    /**
     * Allocates {@code datasetBytes} through {@code cuvsRMMAlloc} and returns the pointer it
     * writes back.
     *
     * @param resourceHandle cuVS resources handle passed through to the native call
     * @param datasetBytes   number of bytes to allocate
     * @throws RuntimeException if {@code cuvsRMMAlloc} does not report success
     */
    public static MemorySegment allocateRMMSegment(long resourceHandle, long datasetBytes) {
        try (Arena localArena = Arena.ofConfined()) {
            // Out-parameter slot; only the pointer value read from it outlives this arena.
            MemorySegment pointerSlot = localArena.allocate(LinkerHelper.C_POINTER);
            checkCuVSError(headers_h.cuvsRMMAlloc(resourceHandle, pointerSlot, datasetBytes), "cuvsRMMAlloc");
            return pointerSlot.get(LinkerHelper.C_POINTER, 0L);
        }
    }

    /** Copy-direction values passed as the {@code kind} argument of {@code cudaMemcpy}. */
    public enum CudaMemcpyKind {
        HOST_TO_HOST(0),
        HOST_TO_DEVICE(1),
        DEVICE_TO_HOST(2),
        DEVICE_TO_DEVICE(3),
        INFER_DIRECTION(4);

        /** Integer value handed to the native call. */
        public final int kind;

        CudaMemcpyKind(int k) {
            this.kind = k;
        }
    }
}

