package dev.argon.esexpr;

import org.eclipse.collections.api.factory.primitive.ByteLists;
import org.eclipse.collections.api.factory.primitive.IntLists;
import org.eclipse.collections.api.factory.primitive.LongLists;
import org.eclipse.collections.api.factory.primitive.ShortLists;
import org.jspecify.annotations.Nullable;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Stream;

/**
 * A reader for the ESExpr binary format.
 */
public class ESExprBinaryReader {
	/**
	 * Create a reader for the ESExpr binary format.
	 * @param symbolTable The initial symbol table used when parsing. The list is copied, so later
	 *                    changes to it do not affect this reader.
	 * @param is The stream.
	 */
	public ESExprBinaryReader(List<String> symbolTable, InputStream is) {
		this.symbolTable = new ArrayList<>(symbolTable);
		this.is = is;
	}

	/**
	 * Create a reader for the ESExpr binary format with an initially empty symbol table.
	 * @param is The stream.
	 */
	public ESExprBinaryReader(InputStream is) {
		this.symbolTable = new ArrayList<>();
		this.is = is;
	}

	// String table used to resolve constructor names, keywords and pooled strings.
	// It grows whenever an APPEND_STRING_TABLE token is read; no synchronization is performed.
	private final List<String> symbolTable;
	private final InputStream is;
	// One-byte pushback slot consumed by next(); -1 means empty.
	// NOTE(review): nothing in this class currently stores a byte here.
	private int nextByte = -1;

	/**
	 * Attempts to read an ESExpr from the stream.
	 * String-table append records encountered before the expression are applied and skipped.
	 * @return The ESExpr, or null if at the end of the stream.
	 * @throws IOException when an error occurs in the underlying stream.
	 * @throws SyntaxException when an expression cannot be read.
	 */
	public @Nullable ESExpr tryRead() throws IOException, SyntaxException {
		for(;;) {
			switch(readExprPlus()) {
				case ExprPlus.Expr(var expr) -> { return expr; }
				case ExprPlus.EndOfFile() -> { return null; }
				case ExprPlus.AppendedToStringTable() -> {}
				// A constructor end or keyword marker is not valid outside a constructor body.
				default -> throw new SyntaxException();
			}
		}
	}

	/**
	 * Read an ESExpr from the stream.
	 * @return The ESExpr
	 * @throws EOFException if at the end of the stream.
	 * @throws IOException when an error occurs in the underlying stream.
	 * @throws SyntaxException when an expression cannot be read.
	 */
	public ESExpr read() throws IOException, SyntaxException {
		var expr = tryRead();
		if(expr == null) {
			throw new EOFException();
		}
		return expr;
	}

	/**
	 * Reads all ESExpr values from the stream.
	 * The returned stream is lazy and ends at the end of the input. Errors are rethrown
	 * unchecked: an {@link IOException} is wrapped in an {@link UncheckedIOException}, and a
	 * {@link SyntaxException} is wrapped in a {@link RuntimeException}.
	 * @return A stream of ESExpr values.
	 */
	public Stream<ESExpr> readAll() {
		return Stream
			.generate(() -> {
				try {
					return tryRead();
				}
				catch(IOException ex) {
					throw new UncheckedIOException(ex);
				}
				catch(SyntaxException ex) {
					throw new RuntimeException(ex);
				}
			})
			.takeWhile(Objects::nonNull);
	}



	// Returns the next byte (0-255), or a negative value at end of stream.
	private int next() throws IOException {
		if(nextByte >= 0) {
			int res = nextByte;
			nextByte = -1;
			return res;
		}

		return is.read();
	}

	// Reads exactly len bytes directly from the underlying stream, bypassing next().
	// This is only correct while the pushback slot is empty, which is always the case here.
	// InputStream.readNBytes(int) reads until len bytes or end of stream; the default
	// implementation buffers in chunks, so a corrupt length does not force a huge up-front array.
	private byte[] readBytesExact(int len) throws IOException {
		byte[] b = is.readNBytes(len);
		if(b.length < len) {
			throw new EOFException();
		}
		return b;
	}

	// Reads byteCount bytes (at most 8) and assembles them as a little-endian value.
	// Callers narrow the result to short/int as needed, keeping only the low bits.
	private long readLittleEndian(int byteCount) throws IOException {
		long value = 0;
		for(int i = 0; i < byteCount; ++i) {
			int b = next();
			if(b < 0) {
				throw new EOFException();
			}
			value |= (long)(b & 0xFF) << (i * 8);
		}
		return value;
	}

	// Converts an encoded length to an int, reporting out-of-range values as syntax errors.
	private static int toLength(BigInteger n) throws SyntaxException {
		try {
			return n.intValueExact();
		}
		catch(ArithmeticException ex) {
			throw new SyntaxException("Length out of range: " + n, ex);
		}
	}

	// Resolves a string-table index, reporting invalid indices as syntax errors.
	private String lookupSymbol(BigInteger index) throws SyntaxException {
		try {
			return symbolTable.get(index.intValueExact());
		}
		catch(ArithmeticException | IndexOutOfBoundsException ex) {
			throw new SyntaxException("Invalid string table index: " + index, ex);
		}
	}

	/**
	 * Reads one token.
	 * For tag bytes whose top three bits are not 111, the top three bits select the token type,
	 * and the low four bits are the start of an integer payload. If bit 0x10 is set, the payload
	 * continues via {@link #readInt}. Bytes 0xE0 and above are fixed tokens.
	 * @return The token, or null at end of stream.
	 */
	private @Nullable BinToken nextToken() throws IOException, SyntaxException {
		int b = next();
		if(b < 0) {
			return null;
		}

		BinToken.WithIntegerType type = switch((b & 0xE0)) {
			case 0x00 -> BinToken.WithIntegerType.CONSTRUCTOR;
			case 0x20 -> BinToken.WithIntegerType.INT;
			case 0x40 -> BinToken.WithIntegerType.NEG_INT;
			case 0x60 -> BinToken.WithIntegerType.STRING;
			case 0x80 -> BinToken.WithIntegerType.STRING_POOL_INDEX;
			case 0xA0 -> BinToken.WithIntegerType.ARRAY8;
			case 0xC0 -> BinToken.WithIntegerType.KEYWORD;
			default -> null;
		};

		if(type == null) {
			return switch(b) {
				case 0xE0 -> BinToken.Fixed.CONSTRUCTOR_END;
				case 0xE1 -> BinToken.Fixed.TRUE;
				case 0xE2 -> BinToken.Fixed.FALSE;
				case 0xE3 -> BinToken.Fixed.NULL0;
				case 0xE8 -> BinToken.Fixed.NULL1;
				case 0xE9 -> BinToken.Fixed.NULL2;
				case 0xEA -> BinToken.Fixed.NULLN;
				case 0xEC -> BinToken.Fixed.FLOAT16;
				case 0xE4 -> BinToken.Fixed.FLOAT32;
				case 0xE5 -> BinToken.Fixed.FLOAT64;
				case 0xE6 -> BinToken.Fixed.CONSTRUCTOR_START_STRING_TABLE;
				case 0xE7 -> BinToken.Fixed.CONSTRUCTOR_START_LIST;
				case 0xEB -> BinToken.Fixed.APPEND_STRING_TABLE;
				case 0xED -> BinToken.Fixed.ARRAY16;
				case 0xEE -> BinToken.Fixed.ARRAY32;
				case 0xEF -> BinToken.Fixed.ARRAY64;
				case 0xF0 -> BinToken.Fixed.ARRAY128;
				default -> throw new SyntaxException();
			};
		}
		else {
			BigInteger i = BigInteger.valueOf(b & 0x0F);
			if((b & 0x10) == 0x10) {
				i = readInt(i, 4);
			}

			return new BinToken.WithInteger(type, i);
		}
	}

	/**
	 * Reads a LEB128-style unsigned integer. Each byte contributes its low 7 bits, least
	 * significant group first, and the high bit (0x80) signals that another byte follows.
	 * @param acc Bits already accumulated.
	 * @param bits Bit offset at which the first group read here is placed.
	 */
	private BigInteger readInt(BigInteger acc, int bits) throws IOException {
		while(true) {
			int b = next();
			if(b < 0) {
				throw new EOFException();
			}

			acc = acc.or(BigInteger.valueOf(b & 0x7F).shiftLeft(bits));
			bits += 7;

			if((b & 0x80) == 0) {
				return acc;
			}
		}
	}

	private sealed interface ExprPlus {
		record Expr(ESExpr expr) implements ExprPlus {}
		record ConstructorEnd() implements ExprPlus {}
		record Keyword(String name) implements ExprPlus {}
		record AppendedToStringTable() implements ExprPlus {}
		record EndOfFile() implements ExprPlus {}
	}

	// Reads one item. That is either a complete expression, or one of the structural markers
	// (constructor end, keyword, string-table append, end of file).
	private ExprPlus readExprPlus() throws SyntaxException, IOException {
		return switch(nextToken()) {
			case null -> new ExprPlus.EndOfFile();

			case BinToken.WithInteger(var type, var value) -> switch(type) {
				case CONSTRUCTOR -> new ExprPlus.Expr(readConstructor(lookupSymbol(value)));
				case INT -> new ExprPlus.Expr(new ESExpr.Int(value));
				// A NEG_INT payload n encodes the value -(n + 1).
				case NEG_INT -> new ExprPlus.Expr(new ESExpr.Int(value.add(BigInteger.ONE).negate()));

				case STRING -> {
					byte[] b = readBytesExact(toLength(value));
					yield new ExprPlus.Expr(new ESExpr.Str(new String(b, StandardCharsets.UTF_8)));
				}

				case STRING_POOL_INDEX -> new ExprPlus.Expr(new ESExpr.Str(lookupSymbol(value)));

				case ARRAY8 -> {
					byte[] b = readBytesExact(toLength(value));
					yield new ExprPlus.Expr(new ESExpr.Array8(ByteLists.immutable.of(b)));
				}

				case KEYWORD -> new ExprPlus.Keyword(lookupSymbol(value));
			};

			case BinToken.Fixed fixed -> switch(fixed) {
				case CONSTRUCTOR_END -> new ExprPlus.ConstructorEnd();
				case TRUE -> new ExprPlus.Expr(new ESExpr.Bool(true));
				case FALSE -> new ExprPlus.Expr(new ESExpr.Bool(false));
				case NULL0 -> new ExprPlus.Expr(new ESExpr.Null(BigInteger.ZERO));
				case NULL1 -> new ExprPlus.Expr(new ESExpr.Null(BigInteger.ONE));
				case NULL2 -> new ExprPlus.Expr(new ESExpr.Null(BigInteger.TWO));
				// NULLN carries an explicit level n and encodes the level n + 3.
				case NULLN -> {
					var n = readInt(BigInteger.ZERO, 0);
					yield new ExprPlus.Expr(new ESExpr.Null(n.add(BigInteger.valueOf(3))));
				}

				case FLOAT16 -> {
					// The two bytes are IEEE 754 binary16 bits. They must be converted, not
					// reinterpreted as binary32 bits (which produced garbage values).
					short bits = (short)readLittleEndian(2);
					yield new ExprPlus.Expr(new ESExpr.Float32(Float.float16ToFloat(bits)));
				}
				case FLOAT32 -> {
					int bits = (int)readLittleEndian(4);
					yield new ExprPlus.Expr(new ESExpr.Float32(Float.intBitsToFloat(bits)));
				}
				case FLOAT64 -> {
					long bits = readLittleEndian(8);
					yield new ExprPlus.Expr(new ESExpr.Float64(Double.longBitsToDouble(bits)));
				}

				case CONSTRUCTOR_START_STRING_TABLE -> new ExprPlus.Expr(readConstructor(BinToken.StringTableName));
				case CONSTRUCTOR_START_LIST -> new ExprPlus.Expr(readConstructor(BinToken.ListName));
				case APPEND_STRING_TABLE -> {
					appendToStringTable(read());
					yield new ExprPlus.AppendedToStringTable();
				}

				case ARRAY16 -> {
					int len = toLength(readInt(BigInteger.ZERO, 0));
					short[] values = new short[len];
					for(int i = 0; i < len; ++i) {
						values[i] = (short)readLittleEndian(2);
					}
					yield new ExprPlus.Expr(new ESExpr.Array16(ShortLists.immutable.of(values)));
				}
				case ARRAY32 -> {
					int len = toLength(readInt(BigInteger.ZERO, 0));
					int[] values = new int[len];
					for(int i = 0; i < len; ++i) {
						values[i] = (int)readLittleEndian(4);
					}
					yield new ExprPlus.Expr(new ESExpr.Array32(IntLists.immutable.of(values)));
				}
				case ARRAY64 -> {
					int len = toLength(readInt(BigInteger.ZERO, 0));
					long[] values = new long[len];
					for(int i = 0; i < len; ++i) {
						values[i] = readLittleEndian(8);
					}
					yield new ExprPlus.Expr(new ESExpr.Array64(LongLists.immutable.of(values)));
				}
				case ARRAY128 -> {
					// Each 128-bit element is read as two consecutive 64-bit little-endian words.
					int len = toLength(readInt(BigInteger.ZERO, 0).multiply(BigInteger.TWO));
					long[] values = new long[len];
					for(int i = 0; i < len; ++i) {
						values[i] = readLittleEndian(8);
					}
					yield new ExprPlus.Expr(new ESExpr.Array128(LongLists.immutable.of(values)));
				}
			};
		};
	}

	// Applies a string-table append record. A bare string adds one entry; anything else must
	// decode as a StringTable, whose values are all appended.
	private void appendToStringTable(ESExpr entries) throws SyntaxException {
		if(entries instanceof ESExpr.Str(var s)) {
			symbolTable.add(s);
			return;
		}

		StringTable decoded;
		try {
			decoded = StringTable.codec().decode(entries);
		}
		catch(DecodeException ex) {
			throw new SyntaxException("Could not decode string table.", ex);
		}

		symbolTable.addAll(decoded.values());
	}

	/**
	 * Reads the body of a constructor up to its CONSTRUCTOR_END marker.
	 * Expressions become positional arguments. A keyword marker consumes the following expression
	 * as that keyword's value; a repeated keyword replaces the earlier value.
	 * @param name The constructor name.
	 * @throws EOFException if the stream ends before the constructor is closed.
	 */
	private ESExpr readConstructor(String name) throws IOException, SyntaxException {
		var args = new ArrayList<ESExpr>();
		var kwargs = new HashMap<String, ESExpr>();

		body:
		while(true) {
			switch(readExprPlus()) {
				case ExprPlus.Expr(var expr) -> args.add(expr);
				case ExprPlus.ConstructorEnd() -> {
					break body;
				}
				case ExprPlus.Keyword(var kw) -> {
					var expr = read();
					kwargs.put(kw, expr);
				}
				case ExprPlus.AppendedToStringTable() -> {}
				case ExprPlus.EndOfFile() -> throw new EOFException();
			}
		}

		return new ESExpr.Constructor(name, args, kwargs);
	}

}
