/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSHostMatrix;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.RowView;
import com.nvidia.cuvs.internal.CuVSMatrixBaseImpl;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.VarHandle;

/**
 * {@link CuVSHostMatrix} implementation backed by a {@link MemorySegment}.
 *
 * <p>Elements are stored row-major: element {@code (r, c)} lives at byte offset
 * {@code (r * columns + c) * valueLayout().byteSize()} of the backing segment.
 */
public class CuVSHostMatrixImpl
extends CuVSMatrixBaseImpl
implements CuVSHostMatrix {
    private final ValueLayout valueLayout;
    protected final VarHandle accessor$vh;

    /**
     * Creates a matrix view over {@code memorySegment}, deriving the element layout from
     * {@code dataType} and using a 32-byte-aligned sequence layout of {@code size * columns} elements.
     *
     * @param memorySegment segment holding the matrix data
     * @param size number of rows
     * @param columns number of columns
     * @param dataType element type
     * @throws ArithmeticException if {@code size * columns} overflows a {@code long}
     */
    public CuVSHostMatrixImpl(MemorySegment memorySegment, long size, long columns, CuVSMatrix.DataType dataType) {
        // Delegate to the shared helpers rather than rebuilding the layouts inline.
        this(memorySegment, size, columns, dataType, valueLayoutFromType(dataType), sequenceLayoutFromType(size, columns, dataType));
    }

    /**
     * Creates a matrix view with caller-supplied layouts.
     *
     * @param memorySegment segment holding the matrix data
     * @param size number of rows
     * @param columns number of columns
     * @param dataType element type
     * @param valueLayout layout of a single element
     * @param sequenceLayout layout of the whole matrix; its sequence-element var handle backs {@link #get(int, int)}
     */
    protected CuVSHostMatrixImpl(MemorySegment memorySegment, long size, long columns, CuVSMatrix.DataType dataType, ValueLayout valueLayout, MemoryLayout sequenceLayout) {
        super(memorySegment, dataType, size, columns);
        this.accessor$vh = sequenceLayout.varHandle(MemoryLayout.PathElement.sequenceElement());
        this.valueLayout = valueLayout;
    }

    /**
     * Maps a {@link CuVSMatrix.DataType} to the C value layout used to read and write its elements.
     *
     * @param dataType element type
     * @return the matching value layout
     */
    protected static ValueLayout valueLayoutFromType(CuVSMatrix.DataType dataType) {
        // Exhaustive enum switch with no default: the compiler flags any new DataType constant here.
        return switch (dataType) {
            case FLOAT -> LinkerHelper.C_FLOAT;
            case INT -> LinkerHelper.C_INT;
            case BYTE -> LinkerHelper.C_CHAR;
        };
    }

    /**
     * Builds a 32-byte-aligned sequence layout of {@code size * columns} elements of {@code dataType}.
     *
     * @param size number of rows
     * @param columns number of columns
     * @param dataType element type
     * @return the sequence layout
     * @throws ArithmeticException if {@code size * columns} overflows a {@code long}
     */
    protected static SequenceLayout sequenceLayoutFromType(long size, long columns, CuVSMatrix.DataType dataType) {
        return MemoryLayout.sequenceLayout(Math.multiplyExact(size, columns), valueLayoutFromType(dataType)).withByteAlignment(32L);
    }

    /**
     * Returns a view of row {@code nodeIndex} that shares this matrix's memory (no copy is made).
     *
     * @param nodeIndex zero-based row index
     * @return a row view of {@code columns} elements
     * @throws IndexOutOfBoundsException if the row lies outside the backing segment (thrown by {@code asSlice})
     */
    @Override
    public RowView getRow(long nodeIndex) {
        final long valueByteSize = this.valueLayout.byteSize();
        final long rowBytes = this.columns * valueByteSize;
        return new SliceRowView(this.memorySegment.asSlice(nodeIndex * rowBytes, rowBytes), this.columns, this.valueLayout, this.dataType, valueByteSize);
    }

    /** Copies the matrix into {@code array}, one row per {@code array[r]}; requires an INT matrix. */
    @Override
    public void toArray(int[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.INT);
        assert ((long) array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long) array[0].length >= this.columns) : "Input array is not large enough";
        copyRowsInto(array);
    }

    /** Copies the matrix into {@code array}, one row per {@code array[r]}; requires a FLOAT matrix. */
    @Override
    public void toArray(float[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.FLOAT);
        assert ((long) array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long) array[0].length >= this.columns) : "Input array is not large enough";
        copyRowsInto(array);
    }

    /** Copies the matrix into {@code array}, one row per {@code array[r]}; requires a BYTE matrix. */
    @Override
    public void toArray(byte[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.BYTE);
        assert ((long) array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long) array[0].length >= this.columns) : "Input array is not large enough";
        copyRowsInto(array);
    }

    /**
     * Copies every row of the matrix into the matching element of {@code rows}.
     *
     * @param rows destination rows; each {@code rows[r]} must be a primitive array whose component
     *     type matches {@link #valueLayout()} and which holds at least {@code columns} elements
     * @throws ArithmeticException if {@code columns} does not fit in an {@code int}
     */
    private void copyRowsInto(Object[] rows) {
        if (this.size == 0) {
            return;
        }
        // Fail loudly instead of silently truncating the per-row element count.
        final int rowLength = Math.toIntExact(this.columns);
        final long rowBytes = this.columns * this.valueLayout.byteSize();
        for (int r = 0; r < this.size; r++) {
            MemorySegment.copy(this.memorySegment, this.valueLayout, r * rowBytes, rows[r], 0, rowLength);
        }
    }

    /**
     * Does nothing: this class does not release the backing segment.
     * NOTE(review): segment ownership presumably lies with whoever allocated it; confirm.
     */
    @Override
    public void close() {
    }

    /**
     * Reads element {@code (row, col)} as an {@code int}.
     *
     * <p>The read goes through the sequence-element var handle. The handle's value type comes from
     * the data type, so this works for INT (exact) and BYTE (widened). A FLOAT matrix cannot be
     * read as {@code int} through a var handle, and the call fails at runtime.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @return the element value
     */
    @Override
    public int get(int row, int col) {
        // VarHandle.get is signature-polymorphic: the (int) cast fixes the call-site return type;
        // without it the expression's type is Object, which cannot be returned as int.
        return (int) this.accessor$vh.get(this.memorySegment, 0L, (long) row * this.columns + col);
    }

    /** @return the layout of a single matrix element */
    public ValueLayout valueLayout() {
        return this.valueLayout;
    }

    /** Row view over a slice of the matrix's backing segment; reads go straight to the shared memory. */
    private static class SliceRowView
    implements RowView {
        private final MemorySegment memorySegment;
        private final long size;
        private final ValueLayout valueLayout;
        private final CuVSMatrix.DataType dataType;
        private final long valueByteSize;

        /**
         * @param slice segment slice covering exactly one row
         * @param size number of elements in the row
         * @param valueLayout layout of a single element
         * @param dataType element type, used only in assertions
         * @param valueByteSize byte size of one element (cached from {@code valueLayout})
         */
        SliceRowView(MemorySegment slice, long size, ValueLayout valueLayout, CuVSMatrix.DataType dataType, long valueByteSize) {
            this.memorySegment = slice;
            this.size = size;
            this.valueLayout = valueLayout;
            this.dataType = dataType;
            this.valueByteSize = valueByteSize;
        }

        @Override
        public long size() {
            return this.size;
        }

        @Override
        public float getAsFloat(long index) {
            assert (this.dataType == CuVSMatrix.DataType.FLOAT);
            return this.memorySegment.get((ValueLayout.OfFloat) this.valueLayout, index * this.valueByteSize);
        }

        @Override
        public byte getAsByte(long index) {
            assert (this.dataType == CuVSMatrix.DataType.BYTE);
            return this.memorySegment.get((ValueLayout.OfByte) this.valueLayout, index * this.valueByteSize);
        }

        @Override
        public int getAsInt(long index) {
            assert (this.dataType == CuVSMatrix.DataType.INT);
            return this.memorySegment.get((ValueLayout.OfInt) this.valueLayout, index * this.valueByteSize);
        }

        @Override
        public void toArray(int[] array) {
            assert ((long) array.length >= this.size) : "Input array is not large enough";
            assert (this.dataType == CuVSMatrix.DataType.INT);
            MemorySegment.copy(this.memorySegment, this.valueLayout, 0L, array, 0, Math.toIntExact(this.size));
        }

        @Override
        public void toArray(float[] array) {
            assert ((long) array.length >= this.size) : "Input array is not large enough";
            assert (this.dataType == CuVSMatrix.DataType.FLOAT);
            MemorySegment.copy(this.memorySegment, this.valueLayout, 0L, array, 0, Math.toIntExact(this.size));
        }

        @Override
        public void toArray(byte[] array) {
            assert ((long) array.length >= this.size) : "Input array is not large enough";
            assert (this.dataType == CuVSMatrix.DataType.BYTE);
            MemorySegment.copy(this.memorySegment, this.valueLayout, 0L, array, 0, Math.toIntExact(this.size));
        }
    }
}

