/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.milo.opcua.stack.server.transport.uasc;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.Timeout;
import java.io.IOException;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.milo.opcua.stack.core.Stack;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.UaSerializationException;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.ChunkDecoder;
import org.eclipse.milo.opcua.stack.core.channel.ChunkEncoder;
import org.eclipse.milo.opcua.stack.core.channel.ExceptionHandler;
import org.eclipse.milo.opcua.stack.core.channel.MessageAbortedException;
import org.eclipse.milo.opcua.stack.core.channel.SecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.SerializationQueue;
import org.eclipse.milo.opcua.stack.core.channel.ServerSecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.HeaderDecoder;
import org.eclipse.milo.opcua.stack.core.channel.messages.ErrorMessage;
import org.eclipse.milo.opcua.stack.core.channel.messages.MessageType;
import org.eclipse.milo.opcua.stack.core.security.CertificateManager;
import org.eclipse.milo.opcua.stack.core.security.CertificateValidator;
import org.eclipse.milo.opcua.stack.core.security.SecurityPolicy;
import org.eclipse.milo.opcua.stack.core.serialization.UaMessage;
import org.eclipse.milo.opcua.stack.core.transport.TransportProfile;
import org.eclipse.milo.opcua.stack.core.types.builtin.ByteString;
import org.eclipse.milo.opcua.stack.core.types.builtin.DateTime;
import org.eclipse.milo.opcua.stack.core.types.builtin.StatusCode;
import org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.Unsigned;
import org.eclipse.milo.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import org.eclipse.milo.opcua.stack.core.types.structured.ChannelSecurityToken;
import org.eclipse.milo.opcua.stack.core.types.structured.EndpointDescription;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import org.eclipse.milo.opcua.stack.core.types.structured.ResponseHeader;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.EndpointUtil;
import org.eclipse.milo.opcua.stack.core.util.NonceUtil;
import org.eclipse.milo.opcua.stack.server.UaStackServer;
import org.eclipse.milo.opcua.stack.server.transport.uasc.UascServerHelloHandler;
import org.eclipse.milo.opcua.stack.server.transport.uasc.UascServerSymmetricHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side decoder for the asymmetric part of UA Secure Conversation: OpenSecureChannel (OPN)
 * and CloseSecureChannel (CLO) messages.
 * <p>
 * OPN chunks are accumulated until a final ('F') chunk arrives; the collected chunks are then
 * handed to the {@link SerializationQueue} for decoding, and the response is encoded there as well.
 * When the first response has been encoded, a {@link UascServerSymmetricHandler} is inserted into
 * the pipeline ahead of this handler.
 */
public class UascServerAsymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /** Channel attribute set to the {@link EndpointDescription} matched when a channel is issued. */
    static final AttributeKey<EndpointDescription> ENDPOINT_KEY = AttributeKey.valueOf("endpoint");

    /** Lower clamp (1 hour, in ms) applied to the client's requested channel lifetime. */
    private static final long SECURE_CHANNEL_LIFETIME_MIN_MS = 3600000L;

    /** Upper clamp (24 hours, in ms) applied to the client's requested channel lifetime. */
    private static final long SECURE_CHANNEL_LIFETIME_MAX_MS = 86400000L;

    // OPC UA status codes used by this handler.
    private static final long BAD_SECURITY_CHECKS_FAILED = 0x80130000L;
    private static final long BAD_TCP_MESSAGE_TYPE_INVALID = 0x807E0000L;
    private static final long BAD_TCP_SECURE_CHANNEL_UNKNOWN = 0x807F0000L;
    private static final long BAD_TCP_MESSAGE_TOO_LARGE = 0x80800000L;
    private static final long BAD_RESPONSE_TOO_LARGE = 0x80B90000L;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    private ServerSecureChannel secureChannel;
    private Timeout secureChannelTimeout;
    private boolean symmetricHandlerAdded = false;

    /** Retained slices of the OPN message currently being received (non-final chunks so far). */
    private List<ByteBuf> chunkBuffers = new ArrayList<>();

    /** Security header of the first chunk of the in-progress message; later chunks must match it. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    private final int maxArrayLength;
    private final int maxStringLength;
    private final int maxChunkCount;
    private final int maxChunkSize;
    private final UaStackServer stackServer;
    private final TransportProfile transportProfile;
    private final SerializationQueue serializationQueue;

    UascServerAsymmetricHandler(UaStackServer stackServer, TransportProfile transportProfile, SerializationQueue serializationQueue) {
        this.stackServer = stackServer;
        this.transportProfile = transportProfile;
        this.serializationQueue = serializationQueue;
        this.maxArrayLength = stackServer.getConfig().getEncodingLimits().getMaxArrayLength();
        this.maxStringLength = stackServer.getConfig().getEncodingLimits().getMaxStringLength();
        this.maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        this.maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();
    }

    /**
     * Waits for a complete chunk, then dispatches it by message type.
     *
     * @throws UaException if the message type is neither OpenSecureChannel nor CloseSecureChannel,
     *                     or if processing the OPN chunk fails.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        // Need the full 8-byte header (3-byte type, 1-byte chunk type, 4-byte size) first.
        if (buffer.readableBytes() < 8) {
            return;
        }

        int messageLength = getMessageLength(buffer, maxChunkSize);

        // Wait until the whole chunk has been received.
        if (buffer.readableBytes() < messageLength) {
            return;
        }

        MessageType messageType = MessageType.fromMediumInt(buffer.getMediumLE(buffer.readerIndex()));

        switch (messageType) {
            case OpenSecureChannel:
                onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
                break;

            case CloseSecureChannel:
                logger.debug("Received CloseSecureChannelRequest");
                buffer.skipBytes(messageLength);
                if (secureChannelTimeout != null) {
                    secureChannelTimeout.cancel();
                    secureChannelTimeout = null;
                }
                ctx.close();
                break;

            default:
                throw new UaException(BAD_TCP_MESSAGE_TYPE_INVALID, "unexpected MessageType: " + messageType);
        }
    }

    /**
     * Processes one OPN chunk.
     * <p>
     * Abort ('A') chunks discard the in-progress message. Other chunks are validated and retained,
     * and a final ('F') chunk triggers decoding of the accumulated message.
     *
     * @param buffer a slice containing exactly one chunk.
     * @throws UaException on a header mismatch, an unknown channel id, certificate/policy problems,
     *                     or when chunk size/count limits are exceeded.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        buffer.skipBytes(3); // MessageType
        char chunkType = (char) buffer.readByte();

        if (chunkType == 'A') {
            chunkBuffers.forEach(ReferenceCounted::release);
            chunkBuffers.clear();
            headerRef.set(null);
            return;
        }

        buffer.skipBytes(4); // MessageSize
        long secureChannelId = buffer.readUnsignedIntLE();

        AsymmetricSecurityHeader header =
            AsymmetricSecurityHeader.decode(buffer, maxArrayLength, maxStringLength);

        // The first chunk's header is recorded; every subsequent chunk must carry an equal one.
        if (!headerRef.compareAndSet(null, header) && !header.equals(headerRef.get())) {
            throw new UaException(BAD_SECURITY_CHECKS_FAILED, "subsequent AsymmetricSecurityHeader did not match");
        }

        // A non-zero id must refer to the channel already established on this connection.
        if (secureChannelId != 0L
            && (secureChannel == null || secureChannelId != secureChannel.getChannelId())) {

            throw new UaException(BAD_TCP_SECURE_CHANNEL_UNKNOWN, "unknown secure channel id: " + secureChannelId);
        }

        if (secureChannel == null) {
            // Only publish the channel once policy resolution and certificate checks succeeded;
            // a half-initialized channel must not short-circuit validation for later chunks.
            secureChannel = createSecureChannel(header);
        }

        // buffer is a slice of exactly this chunk: rewinding to 0 makes it span the whole chunk,
        // both for the size check and for the retained copy handed to the chunk decoder.
        int chunkSize = buffer.readerIndex(0).readableBytes();

        if (chunkSize > maxChunkSize) {
            throw new UaException(BAD_TCP_MESSAGE_TOO_LARGE,
                String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        chunkBuffers.add(buffer.retain());

        if (maxChunkCount > 0 && chunkBuffers.size() > maxChunkCount) {
            throw new UaException(BAD_TCP_MESSAGE_TOO_LARGE,
                String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            // Hand ownership of the collected chunks to the decoder and start a fresh message.
            List<ByteBuf> buffersToDecode = chunkBuffers;
            chunkBuffers = new ArrayList<>();
            headerRef.set(null);

            decodeOpenSecureChannelRequest(ctx, secureChannelId, buffersToDecode);
        }
    }

    /**
     * Builds a new {@link ServerSecureChannel} from the first OPN security header.
     * <p>
     * For any policy other than {@code None}, this validates the sender certificate and its trust
     * chain. It also loads the local certificate chain and key pair identified by the receiver
     * thumbprint.
     *
     * @return a fully configured channel; the caller assigns it to {@link #secureChannel}.
     * @throws UaException if the policy URI is unknown, validation fails, or no local certificate
     *                     matches the thumbprint.
     */
    private ServerSecureChannel createSecureChannel(AsymmetricSecurityHeader header) throws UaException {
        ServerSecureChannel channel = new ServerSecureChannel();
        channel.setChannelId(stackServer.getNextChannelId());

        SecurityPolicy securityPolicy = SecurityPolicy.fromUri(header.getSecurityPolicyUri());
        channel.setSecurityPolicy(securityPolicy);

        if (securityPolicy != SecurityPolicy.None) {
            channel.setRemoteCertificate(header.getSenderCertificate().bytesOrEmpty());

            CertificateValidator certificateValidator = stackServer.getConfig().getCertificateValidator();
            certificateValidator.validate(channel.getRemoteCertificate());
            certificateValidator.verifyTrustChain(channel.getRemoteCertificateChain());

            CertificateManager certificateManager = stackServer.getConfig().getCertificateManager();
            Optional<X509Certificate[]> localCertificateChain =
                certificateManager.getCertificateChain(header.getReceiverThumbprint());
            Optional<KeyPair> keyPair = certificateManager.getKeyPair(header.getReceiverThumbprint());

            if (!localCertificateChain.isPresent() || !keyPair.isPresent()) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "no certificate for provided thumbprint");
            }

            X509Certificate[] chain = localCertificateChain.get();
            channel.setLocalCertificate(chain[0]);
            channel.setLocalCertificateChain(chain);
            channel.setKeyPair(keyPair.get());
        }

        return channel;
    }

    /**
     * Decodes the accumulated chunks on the serialization queue and responds to the request.
     * <p>
     * Decoding errors close the channel; aborted messages are only logged.
     */
    private void decodeOpenSecureChannelRequest(
        ChannelHandlerContext ctx,
        long secureChannelId,
        List<ByteBuf> buffersToDecode) {

        serializationQueue.decode((binaryDecoder, chunkDecoder) ->
            chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode, new ChunkDecoder.Callback() {
                @Override
                public void onDecodingError(UaException ex) {
                    logger.error("Error decoding asymmetric message: {}", ex.getMessage(), ex);
                    ctx.close();
                }

                @Override
                public void onMessageAborted(MessageAbortedException ex) {
                    logger.warn("Asymmetric message aborted. error={} reason={}", ex.getStatusCode(), ex.getMessage());
                }

                @Override
                public void onMessageDecoded(ByteBuf message, long requestId) {
                    try {
                        OpenSecureChannelRequest request =
                            (OpenSecureChannelRequest) binaryDecoder.setBuffer(message).readMessage(null);

                        logger.debug("Received OpenSecureChannelRequest ({}, id={}).",
                            request.getRequestType(), secureChannelId);

                        sendOpenSecureChannelResponse(ctx, requestId, request);
                    } catch (Throwable t) {
                        logger.error("Error decoding OpenSecureChannelRequest", t);
                        ctx.close();
                    } finally {
                        message.release();
                        buffersToDecode.clear();
                    }
                }
            })
        );
    }

    /**
     * Builds, encodes and writes the OpenSecureChannelResponse on the serialization queue.
     * <p>
     * After the first successful encoding, a {@link UascServerSymmetricHandler} is installed ahead
     * of this handler.
     */
    private void sendOpenSecureChannelResponse(ChannelHandlerContext ctx, long requestId, OpenSecureChannelRequest request) {
        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.pooledBuffer();

            try {
                OpenSecureChannelResponse response = openSecureChannel(ctx, request);

                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.writeMessage(null, response);

                checkMessageSize(messageBuffer);

                chunkEncoder.encodeAsymmetric(
                    secureChannel,
                    requestId,
                    messageBuffer,
                    MessageType.OpenSecureChannel,
                    new ChunkEncoder.Callback() {
                        @Override
                        public void onEncodingError(UaException ex) {
                            logger.error("Error encoding OpenSecureChannelResponse: {}", ex.getMessage(), ex);
                            ctx.fireExceptionCaught(ex);
                        }

                        @Override
                        public void onMessageEncoded(List<ByteBuf> messageChunks, long requestId) {
                            if (!symmetricHandlerAdded) {
                                UascServerSymmetricHandler symmetricHandler =
                                    new UascServerSymmetricHandler(stackServer, serializationQueue, secureChannel);

                                ctx.pipeline().addBefore(ctx.name(), null, symmetricHandler);

                                symmetricHandlerAdded = true;
                            }

                            CompositeByteBuf chunkComposite = BufferUtil.compositeBuffer();

                            for (ByteBuf chunk : messageChunks) {
                                chunkComposite.addComponent(chunk);
                                chunkComposite.writerIndex(chunkComposite.writerIndex() + chunk.readableBytes());
                            }

                            ctx.writeAndFlush(chunkComposite, ctx.voidPromise());

                            logger.debug("Sent OpenSecureChannelResponse.");
                        }
                    }
                );
            } catch (UaException e) {
                logger.error("Error installing security token: {}", e.getStatusCode(), e);
                ctx.close();
            } catch (UaSerializationException e) {
                ctx.fireExceptionCaught(e);
            } finally {
                messageBuffer.release();
            }
        });
    }

    /**
     * Applies an Issue or Renew request to {@link #secureChannel}.
     * <p>
     * This generates nonces and keys when symmetric signing is enabled, rotates the channel
     * security, and re-arms the lifetime timeout.
     *
     * @return the response carrying the new security token and the server nonce.
     * @throws UaException if no endpoint matches on Issue, if a Renew changes the security mode,
     *                     or if the client nonce is missing or too short.
     */
    private OpenSecureChannelResponse openSecureChannel(ChannelHandlerContext ctx, OpenSecureChannelRequest request) throws UaException {
        SecurityTokenRequestType requestType = request.getRequestType();

        if (requestType == SecurityTokenRequestType.Issue) {
            secureChannel.setMessageSecurityMode(request.getSecurityMode());

            String endpointUrl = (String) ctx.channel().attr(UascServerHelloHandler.ENDPOINT_URL_KEY).get();

            EndpointDescription endpoint = findMatchingEndpoint(endpointUrl, request);

            ctx.channel().attr(ENDPOINT_KEY).set(endpoint);
        }

        if (requestType == SecurityTokenRequestType.Renew
            && secureChannel.getMessageSecurityMode() != request.getSecurityMode()) {

            throw new UaException(BAD_SECURITY_CHECKS_FAILED, "secure channel renewal requested a different MessageSecurityMode.");
        }

        // Clamp the client's requested lifetime into [MIN, MAX].
        long channelLifetime = request.getRequestedLifetime().longValue();
        channelLifetime = Math.min(SECURE_CHANNEL_LIFETIME_MAX_MS, channelLifetime);
        channelLifetime = Math.max(SECURE_CHANNEL_LIFETIME_MIN_MS, channelLifetime);

        ChannelSecurityToken newToken = new ChannelSecurityToken(
            Unsigned.uint((long) secureChannel.getChannelId()),
            Unsigned.uint((long) stackServer.getNextTokenId()),
            DateTime.now(),
            Unsigned.uint(channelLifetime)
        );

        ChannelSecurity.SecurityKeys newKeys = null;

        if (secureChannel.isSymmetricSigningEnabled()) {
            ByteString remoteNonce = request.getClientNonce();

            if (remoteNonce == null || remoteNonce.isNull()) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "remote nonce must be non-null");
            }

            int minNonceLength = NonceUtil.getNonceLength(secureChannel.getSecurityPolicy());

            if (remoteNonce.length() < minNonceLength) {
                String message = String.format("remote nonce length must be at least %d bytes", minNonceLength);
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, message);
            }

            ByteString localNonce = NonceUtil.generateNonce(secureChannel.getSecurityPolicy());

            secureChannel.setLocalNonce(localNonce);
            secureChannel.setRemoteNonce(remoteNonce);

            newKeys = ChannelSecurity.generateKeyPair(
                secureChannel,
                secureChannel.getRemoteNonce(),
                secureChannel.getLocalNonce()
            );
        }

        // Keep the previous keys/token alongside the new ones.
        ChannelSecurity oldSecrets = secureChannel.getChannelSecurity();
        ChannelSecurity.SecurityKeys oldKeys = oldSecrets != null ? oldSecrets.getCurrentKeys() : null;
        ChannelSecurityToken oldToken = oldSecrets != null ? oldSecrets.getCurrentToken() : null;

        secureChannel.setChannelSecurity(new ChannelSecurity(newKeys, newToken, oldKeys, oldToken));

        // Re-arm the lifetime timeout; if cancel() fails, the existing timeout has already fired.
        if (secureChannelTimeout == null || secureChannelTimeout.cancel()) {
            final long lifetime = channelLifetime;

            secureChannelTimeout = Stack.sharedWheelTimer().newTimeout(timeout -> {
                logger.debug(
                    "SecureChannel renewal timed out after {}ms. id={}, channel={}",
                    lifetime, secureChannel.getChannelId(), ctx.channel());

                ctx.close();
            }, channelLifetime, TimeUnit.MILLISECONDS);
        }

        ResponseHeader responseHeader = new ResponseHeader(
            DateTime.now(),
            request.getRequestHeader().getRequestHandle(),
            StatusCode.GOOD,
            null, null, null
        );

        return new OpenSecureChannelResponse(
            responseHeader,
            Unsigned.uint(0L),
            newToken,
            secureChannel.getLocalNonce()
        );
    }

    /**
     * Finds the first endpoint matching the transport profile, the URL path the client connected
     * to, and the channel's security policy and requested security mode.
     *
     * @throws UaException if no endpoint matches.
     */
    private EndpointDescription findMatchingEndpoint(String endpointUrl, OpenSecureChannelRequest request) throws UaException {
        return stackServer.getEndpointDescriptions().stream()
            .filter(e -> {
                boolean transportMatch = Objects.equals(e.getTransportProfileUri(), transportProfile.getUri());
                boolean pathMatch = Objects.equals(EndpointUtil.getPath(e.getEndpointUrl()), EndpointUtil.getPath(endpointUrl));
                boolean securityPolicyMatch = Objects.equals(e.getSecurityPolicyUri(), secureChannel.getSecurityPolicy().getUri());
                boolean securityModeMatch = Objects.equals(e.getSecurityMode(), request.getSecurityMode());

                return transportMatch && pathMatch && securityPolicyMatch && securityModeMatch;
            })
            .findFirst()
            .orElseThrow(() -> {
                String message = String.format(
                    "no matching endpoint found: transportProfile=%s, endpointUrl=%s, securityPolicy=%s, securityMode=%s",
                    transportProfile, endpointUrl, secureChannel.getSecurityPolicy(), request.getSecurityMode());

                return new UaException(BAD_SECURITY_CHECKS_FAILED, message);
            });
    }

    /**
     * Rejects responses larger than the peer's advertised max message size (0 means no limit).
     *
     * @throws UaSerializationException if the encoded response is too large.
     */
    private void checkMessageSize(ByteBuf messageBuffer) throws UaSerializationException {
        int messageSize = messageBuffer.readableBytes();
        int remoteMaxMessageSize = serializationQueue.getParameters().getRemoteMaxMessageSize();

        if (remoteMaxMessageSize > 0 && messageSize > remoteMaxMessageSize) {
            throw new UaSerializationException(BAD_RESPONSE_TOO_LARGE,
                "response exceeds remote max message size: " + messageSize + " > " + remoteMaxMessageSize);
        }
    }

    /** Releases any chunks of an unfinished OPN message and forgets its recorded header. */
    private void releaseChunkBuffers() {
        chunkBuffers.forEach(ReferenceCountUtil::safeRelease);
        chunkBuffers.clear();
        headerRef.set(null);
    }

    /**
     * Releases retained chunk slices when the handler leaves the pipeline (e.g. the connection
     * closed mid-message); otherwise they would leak.
     */
    @Override
    protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
        releaseChunkBuffers();
        super.handlerRemoved0(ctx);
    }

    /**
     * Discards partial message state.
     * <p>
     * IOExceptions just close the channel. Any other exception is reported to the client via
     * {@link ExceptionHandler#sendErrorMessage}.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        releaseChunkBuffers();

        Object remoteAddress = ctx.channel().remoteAddress();

        if (cause instanceof IOException) {
            ctx.close();
            logger.debug("[remote={}] IOException caught; channel closed", remoteAddress);
        } else {
            ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

            if (cause instanceof UaException) {
                logger.debug("[remote={}] UaException caught; sent {}", remoteAddress, errorMessage, cause);
            } else {
                logger.error("[remote={}] Exception caught; sent {}", remoteAddress, errorMessage, cause);
            }
        }
    }

}

