/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal.common;

import com.nvidia.cuvs.GPUInfo;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.panama.gpuInfo;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;

/**
 * Static helpers for the cuVS Java bindings: native error checking, GPU discovery through the
 * native {@code get_gpu_info} function, and copying Java data into off-heap {@link MemorySegment}s.
 *
 * <p>Not instantiable; all members are static.
 */
public class Util {
    /** Status value that {@link #checkError(int, String)} accepts as success; any other value is an error. */
    public static final int CUVS_SUCCESS = 1;
    private static final MethodHandle getGpuInfoMethodHandle = LinkerHelper.downcallHandle("get_gpu_info", FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS), new Linker.Option[0]);
    private static final MethodHandle getLastErrorTextMethodHandle = LinkerHelper.downcallHandle("cuvsGetLastErrorText", FunctionDescriptor.of(ValueLayout.ADDRESS), new Linker.Option[0]);
    /** Upper bound, in bytes, used when reinterpreting the native last-error string. */
    static final long MAX_ERROR_TEXT = 1000000L;

    /** Default minimum compute capability used by {@link #compatibleGPUs()}. */
    private static final double DEFAULT_MIN_COMPUTE_CAPABILITY = 7.0;
    /** Default minimum device memory, in MiB, used by {@link #compatibleGPUs()}. */
    private static final int DEFAULT_MIN_DEVICE_MEMORY_MB = 8192;
    /**
     * Number of {@code gpuInfo} entries allocated for the native call. The device count is only
     * known after the call, so a fixed-capacity buffer is passed in up front.
     */
    private static final int GPU_INFO_CAPACITY = 1024;

    private Util() {
    }

    /**
     * Throws if a native call did not report {@link #CUVS_SUCCESS}.
     *
     * @param value  status returned by the native call
     * @param caller name of the native function, included in the exception message
     * @throws RuntimeException if {@code value != CUVS_SUCCESS}; the message includes the native
     *                          last-error text
     */
    public static void checkError(int value, String caller) {
        if (value != CUVS_SUCCESS) {
            String errorMsg = Util.getLastErrorText();
            throw new RuntimeException(caller + " returned " + value + "[" + errorMsg + "]");
        }
    }

    /**
     * Fetches the text of the last native error via {@code cuvsGetLastErrorText}.
     *
     * @return the decoded error text, or {@code "no last error text"} if the native side returned NULL
     * @throws RuntimeException wrapping any failure of the downcall itself
     */
    static String getLastErrorText() {
        try {
            // The cast selects the ()MemorySegment call type required by invokeExact.
            MemorySegment seg = (MemorySegment) getLastErrorTextMethodHandle.invokeExact();
            if (seg.equals(MemorySegment.NULL)) {
                return "no last error text";
            }
            // The returned pointer has zero length; widen it so getString can scan for the terminator.
            return seg.reinterpret(MAX_ERROR_TEXT).getString(0L);
        }
        catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Returns the GPUs that have compute capability of at least 7.0 and at least 8192 MiB of total memory.
     *
     * @return the matching GPUs, in the order reported by {@link #availableGPUs()}
     * @throws Throwable if GPU discovery fails
     */
    public static List<GPUInfo> compatibleGPUs() throws Throwable {
        return Util.compatibleGPUs(DEFAULT_MIN_COMPUTE_CAPABILITY, DEFAULT_MIN_DEVICE_MEMORY_MB);
    }

    /**
     * Returns the available GPUs that meet both thresholds.
     *
     * @param minComputeCapability minimum (inclusive) compute capability
     * @param minDeviceMemoryMB    minimum (inclusive) total memory in MiB; compared against
     *                             {@code totalMemory()} after multiplying by 2^20 (assumes
     *                             {@code totalMemory()} is in bytes)
     * @return the matching GPUs, in discovery order
     * @throws Throwable if GPU discovery fails
     */
    public static List<GPUInfo> compatibleGPUs(double minComputeCapability, int minDeviceMemoryMB) throws Throwable {
        double minDeviceMemoryB = Math.pow(2.0, 20.0) * (double) minDeviceMemoryMB;
        ArrayList<GPUInfo> compatibleGPUs = new ArrayList<GPUInfo>();
        for (GPUInfo candidate : Util.availableGPUs()) {
            boolean capableEnough = (double) candidate.computeCapability() >= minComputeCapability;
            boolean largeEnough = (double) candidate.totalMemory() >= minDeviceMemoryB;
            if (capableEnough && largeEnough) {
                compatibleGPUs.add(candidate);
            }
        }
        return compatibleGPUs;
    }

    /**
     * Queries the native {@code get_gpu_info} function and converts each reported entry into a
     * {@link GPUInfo}.
     *
     * @return one {@link GPUInfo} per device reported by the native call
     * @throws IllegalStateException if the native call reports a GPU count outside
     *                               {@code [0, GPU_INFO_CAPACITY]}
     * @throws Throwable             if the downcall fails
     */
    public static List<GPUInfo> availableGPUs() throws Throwable {
        ArrayList<GPUInfo> results = new ArrayList<GPUInfo>();
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment returnValueMemorySegment = localArena.allocate(LinkerHelper.C_INT);
            MemorySegment numGpuMemorySegment = localArena.allocate(LinkerHelper.C_INT);
            MemorySegment gpuInfoArray = gpuInfo.allocateArray(GPU_INFO_CAPACITY, localArena);
            getGpuInfoMethodHandle.invokeExact(returnValueMemorySegment, numGpuMemorySegment, gpuInfoArray);
            // NOTE(review): the status written to returnValueMemorySegment is not inspected; confirm
            // the native contract before enforcing it with checkError.
            int numGPUs = numGpuMemorySegment.get(ValueLayout.JAVA_INT, 0L);
            if (numGPUs < 0 || numGPUs > GPU_INFO_CAPACITY) {
                throw new IllegalStateException("get_gpu_info reported " + numGPUs
                        + " GPUs; expected a value in [0, " + GPU_INFO_CAPACITY + "]");
            }

            // Build the accessors once from the element layout and address each entry with a base
            // offset, instead of creating new var handles for every GPU and every name character.
            MemoryLayout elementLayout = gpuInfo.layout();
            long stride = elementLayout.byteSize();
            VarHandle gpuIdVarHandle = elementLayout.varHandle(MemoryLayout.PathElement.groupElement("gpu_id"));
            VarHandle freeMemoryVarHandle = elementLayout.varHandle(MemoryLayout.PathElement.groupElement("free_memory"));
            VarHandle totalMemoryVarHandle = elementLayout.varHandle(MemoryLayout.PathElement.groupElement("total_memory"));
            VarHandle computeCapabilityVarHandle = elementLayout.varHandle(MemoryLayout.PathElement.groupElement("compute_capability"));
            long nameOffset = elementLayout.byteOffset(MemoryLayout.PathElement.groupElement("name"));
            long nameSize = elementLayout.select(MemoryLayout.PathElement.groupElement("name")).byteSize();

            for (int i = 0; i < numGPUs; ++i) {
                long base = i * stride;
                // Bound the read to the name field; getString stops at the first NUL and fails if
                // the field has none.
                String name = gpuInfoArray.asSlice(base + nameOffset, nameSize).getString(0L).trim();
                // NOTE(review): this file is CFR output, and the casts that fix the signature-polymorphic
                // return types of these get(...) calls appear to have been dropped; restore them to
                // match GPUInfo's constructor parameter types.
                results.add(new GPUInfo(gpuIdVarHandle.get(gpuInfoArray, base), name, freeMemoryVarHandle.get(gpuInfoArray, base), totalMemoryVarHandle.get(gpuInfoArray, base), computeCapabilityVarHandle.get(gpuInfoArray, base)));
            }
        }
        return results;
    }

    /**
     * Copies {@code str} into a NUL-terminated, UTF-8 encoded segment allocated from {@code arena}.
     *
     * @param arena arena that owns the returned segment
     * @param str   text to copy; must not be null
     * @return a segment holding the UTF-8 bytes of {@code str} followed by a {@code '\0'} byte
     */
    public static MemorySegment buildMemorySegment(Arena arena, String str) {
        // allocateFrom encodes as UTF-8 and appends the terminator; the previous per-char byte cast
        // corrupted any non-ASCII character.
        return arena.allocateFrom(str);
    }

    /**
     * Copies {@code data} into a newly allocated segment of {@code C_LONG} elements.
     *
     * @param arena arena that owns the returned segment
     * @param data  values to copy
     * @return a segment containing {@code data.length} elements
     */
    public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
        int cells = data.length;
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, LinkerHelper.C_LONG);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        MemorySegment.copy(data, 0, dataMemorySegment, LinkerHelper.C_LONG, 0L, cells);
        return dataMemorySegment;
    }

    /**
     * Copies {@code data} into a newly allocated segment of {@code C_CHAR} elements.
     *
     * @param arena arena that owns the returned segment
     * @param data  bytes to copy
     * @return a segment containing {@code data.length} elements
     */
    public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
        int cells = data.length;
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, LinkerHelper.C_CHAR);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        MemorySegment.copy(data, 0, dataMemorySegment, LinkerHelper.C_CHAR, 0L, cells);
        return dataMemorySegment;
    }

    /**
     * Copies a rectangular matrix into a newly allocated, row-major segment of {@code C_FLOAT}s.
     *
     * @param arena arena that owns the returned segment
     * @param data  matrix rows; every row must have the same length as {@code data[0]}
     * @return a segment of {@code rows * cols} floats, row {@code r} starting at element {@code r * cols}
     * @throws IllegalArgumentException if the rows do not all have the same length
     */
    public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Validate before allocating: a short row would otherwise fail mid-copy and a long row
        // would be silently truncated.
        for (int r = 1; r < rows; ++r) {
            if (data[r].length != cols) {
                throw new IllegalArgumentException("row " + r + " has " + data[r].length
                        + " columns; expected " + cols);
            }
        }
        SequenceLayout dataMemoryLayout = MemoryLayout.sequenceLayout((long) rows * cols, LinkerHelper.C_FLOAT);
        MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
        long rowBytes = (long) cols * LinkerHelper.C_FLOAT.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, dataMemorySegment, LinkerHelper.C_FLOAT, r * rowBytes, cols);
        }
        return dataMemorySegment;
    }

    /**
     * Concatenates fixed-width bit sets into one, with element {@code i} occupying bits
     * {@code [i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet)}.
     *
     * <p>A {@code null} or empty element fills its whole range with set bits. Bits of an element
     * at positions {@code >= maxSizeOfEachBitSet} are ignored.
     *
     * @param arr                 bit sets to concatenate; elements may be {@code null}
     * @param maxSizeOfEachBitSet width of each element's range in the result
     * @return the concatenated bit set
     */
    public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
        BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
        for (int i = 0; i < arr.length; ++i) {
            int start = i * maxSizeOfEachBitSet;
            BitSet b = arr[i];
            if (b == null || b.isEmpty()) {
                ret.set(start, start + maxSizeOfEachBitSet);
                continue;
            }
            // ret starts empty and the ranges are disjoint, so copying only the set bits is enough.
            for (int j = b.nextSetBit(0); j >= 0 && j < maxSizeOfEachBitSet; j = b.nextSetBit(j + 1)) {
                ret.set(start + j);
            }
        }
        return ret;
    }
}

