/*
 * Decompiled with CFR 0.152.
 */
package eu.clarussecure.proxy.protocol.plugins.pgsql.message.ssl;

import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlConstants;
import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlSession;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.sql.TransferMode;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.ssl.SessionMessageTransferMode;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.codec.PgsqlRawPartCodec;
import eu.clarussecure.proxy.protocol.plugins.tcp.TCPConstants;
import eu.clarussecure.proxy.protocol.plugins.tcp.ssl.SSLSessionInitializer;
import eu.clarussecure.proxy.spi.CString;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Handles the PostgreSQL session start-up phase (SSL request, SSL response, start-up message and
 * cancel request) on behalf of the proxy.
 *
 * <p>For each message it decides a {@link TransferMode} (forward, forget, error or orchestrate).
 * It installs SSL handlers on the frontend and/or backend channels according to the SSL modes
 * reported by {@link SSLSessionInitializer}.
 *
 * <p>State is shared between the frontend-side calls ({@link #processSSLRequest},
 * {@link #processStartupMessage}) and the backend-side calls ({@link #processSSLResponse}).
 * Flags are kept in atomics. {@link #waitForResponses(ChannelHandlerContext)} blocks on this
 * object's monitor until {@link #processSSLResponse} has been called at least once per
 * server-side channel.
 */
public class SessionInitializer {
    private static final Logger LOGGER = LoggerFactory.getLogger(SessionInitializer.class);

    /** SSL response byte accepted by the code below as "SSL accepted" ('S'). */
    private static final byte SSL_ACCEPTED = 'S';
    /** SSL response byte used as "NO_SSL" ('N'). */
    private static final byte SSL_REFUSED = 'N';
    /** Error-detail field key carrying the severity ('S'). */
    private static final byte ERROR_FIELD_SEVERITY = 'S';
    /** Error-detail field key carrying the message text ('M'). */
    private static final byte ERROR_FIELD_MESSAGE = 'M';

    private final SSLSessionInitializer sslSessionInitializer = new SSLSessionInitializer();
    private final AtomicBoolean sslRequestReceived = new AtomicBoolean(false);
    /** Number of SSL responses processed so far; waited on by {@link #waitForResponses}. */
    private final AtomicInteger sslResponseReceived = new AtomicInteger(0);
    private final AtomicBoolean sessionEncryptedOnFrontendSide = new AtomicBoolean(false);
    private final AtomicBoolean sessionEncryptedOnBackendSide = new AtomicBoolean(false);

    /**
     * Processes an SSL request received from the frontend.
     *
     * <p>The outcome depends on the configured SSL modes:
     * <ul>
     * <li>If SSL is disabled on the frontend side, the result is an {@code ERROR} with a FATAL
     * "SSL is disabled" detail.</li>
     * <li>If SSL is disabled on the backend side, an SSL handler is installed on the frontend
     * channel. The request is marked {@code FORGET} with {@code 'S'} as the response.</li>
     * <li>Otherwise the request is marked {@code FORWARD}.</li>
     * </ul>
     *
     * @param ctx the channel handler context the request arrived on
     * @param code the SSL request code (only logged)
     * @return the transfer mode, with the response byte or error details when applicable
     * @throws IOException propagated from installing the frontend SSL handler
     */
    public SessionMessageTransferMode<Void, Byte> processSSLRequest(ChannelHandlerContext ctx, int code) throws IOException {
        LOGGER.debug("SSL request code: {}", code);
        TransferMode transferMode = TransferMode.FORWARD;
        Byte response = null;
        Map<Byte, CString> errorDetails = null;
        this.sslRequestReceived.set(true);
        if (this.sslSessionInitializer.getClientMode() == SSLSessionInitializer.SSLMode.DISABLED) {
            LOGGER.trace("Reply to the frontend that SSL is disabled");
            transferMode = TransferMode.ERROR;
            errorDetails = fatalError("SSL is disabled");
            LOGGER.trace("SSL request is ignored (due to an error on the frontend side)");
        } else {
            LOGGER.trace("SSL is allowed or required on the frontend side");
            if (this.sslSessionInitializer.getServerMode() == SSLSessionInitializer.SSLMode.DISABLED) {
                this.addSSLHandlerOnFrontendSide(ctx);
                LOGGER.trace("Reply SSL to the frontend");
                transferMode = TransferMode.FORGET;
                response = SSL_ACCEPTED;
                LOGGER.trace("SSL request is ignored (SSL is disabled on the backend side)");
            } else {
                LOGGER.trace("Forward the SSL request (SSL is allowed or required on the backend side)");
            }
        }
        SessionMessageTransferMode<Void, Byte> mode = new SessionMessageTransferMode<>(null, transferMode, response, errorDetails);
        LOGGER.debug("SSL request processed: transfer mode={}", mode);
        return mode;
    }

    /**
     * Processes an SSL response received from a backend.
     *
     * <p>On {@code 'S'}:
     * <ul>
     * <li>If the frontend sent an SSL request:
     * <ul>
     * <li>If frontend SSL is disabled, the code is rewritten to {@code 'N'}.</li>
     * <li>Otherwise, if this backend is not the preferred endpoint, the response is marked
     * {@code FORGET}.</li>
     * <li>Otherwise it is forwarded and an SSL handler is installed on the frontend.</li>
     * </ul>
     * </li>
     * <li>If the frontend sent no SSL request, the response is marked {@code FORGET} and this
     * backend's session-initialization response handler is removed.</li>
     * <li>In every {@code 'S'} case an SSL handler is installed on the backend side.</li>
     * </ul>
     *
     * <p>On {@code 'N'}, the same request-received / preferred-endpoint rules apply, without any SSL
     * handler. Any other code is forwarded unchanged.
     *
     * <p>Finally the response counter is incremented and threads blocked in
     * {@link #waitForResponses} are notified.
     *
     * @param ctx the server-side channel handler context the response arrived on
     * @param code the SSL response byte
     * @return the transfer mode and the (possibly rewritten, possibly {@code null}) code
     * @throws IOException propagated from installing an SSL handler
     */
    public SessionMessageTransferMode<Byte, Void> processSSLResponse(ChannelHandlerContext ctx, byte code) throws IOException {
        LOGGER.debug("SSL response code: {}", code);
        TransferMode transferMode = TransferMode.FORWARD;
        Byte newCode = code;
        if (code == SSL_ACCEPTED) {
            if (this.sslRequestReceived.get()) {
                if (this.sslSessionInitializer.getClientMode() == SSLSessionInitializer.SSLMode.DISABLED) {
                    LOGGER.trace("SSL is disabled on the frontend side");
                    LOGGER.trace("Modify SSL code to NO_SSL");
                    newCode = SSL_REFUSED;
                } else if (!this.isPreferredServerEndPoint(ctx)) {
                    transferMode = TransferMode.FORGET;
                    newCode = null;
                } else {
                    LOGGER.trace("Forward the SSL response (SSL was required by the frontend)");
                    this.addSSLHandlerOnFrontendSide(ctx);
                }
            } else {
                LOGGER.trace("SSL response is ignored (frontend did not request SSL)");
                transferMode = TransferMode.FORGET;
                newCode = null;
                this.removeSessionInitializationResponseHandler(ctx, false);
            }
            this.addSSLHandlerOnBackendSide(ctx);
        } else if (code == SSL_REFUSED) {
            if (this.sslRequestReceived.get()) {
                if (!this.isPreferredServerEndPoint(ctx)) {
                    transferMode = TransferMode.FORGET;
                    newCode = null;
                } else {
                    LOGGER.trace("Forward the SSL response (SSL was required by the frontend)");
                }
            } else {
                LOGGER.trace("SSL response is ignored (frontend did not request SSL)");
                transferMode = TransferMode.FORGET;
                newCode = null;
                this.removeSessionInitializationResponseHandler(ctx, false);
            }
        }
        // NOTE(review): if anything above throws, the counter is never incremented and
        // waitForResponses() may block indefinitely - confirm whether that is intended.
        synchronized (this) {
            this.sslResponseReceived.incrementAndGet();
            this.notifyAll();
        }
        SessionMessageTransferMode<Byte, Void> mode = new SessionMessageTransferMode<Byte, Void>(newCode, transferMode);
        LOGGER.debug("SSL response processed: new code={}, transfer mode={}", newCode, mode);
        return mode;
    }

    /**
     * Tells whether the server endpoint number attached to the channel of {@code ctx} equals the
     * preferred server endpoint number attached to the same channel.
     */
    private boolean isPreferredServerEndPoint(ChannelHandlerContext ctx) {
        int serverEndpoint = this.getServerEndPoint(ctx);
        int preferredServerEndpoint = this.getPreferredServerEndPoint(ctx);
        return serverEndpoint == preferredServerEndpoint;
    }

    /**
     * Reads the server endpoint number stored on the channel of {@code ctx}.
     *
     * @throws NullPointerException if the attribute is not set
     * @throws IndexOutOfBoundsException if the value is not a valid index into the session's
     *         server-side channels
     */
    private int getServerEndPoint(ChannelHandlerContext ctx) {
        Integer serverEndpointNumber = (Integer) ctx.channel().attr(TCPConstants.SERVER_ENDPOINT_NUMBER_KEY).get();
        if (serverEndpointNumber == null) {
            throw new NullPointerException(TCPConstants.SERVER_ENDPOINT_NUMBER_KEY.name() + " is not set");
        }
        int serverCount = this.getPgsqlSession(ctx).getServerSideChannels().size();
        if (serverEndpointNumber < 0 || serverEndpointNumber >= serverCount) {
            throw new IndexOutOfBoundsException(String.format("invalid %s: value: %d, number of server endpoints: %d ", TCPConstants.SERVER_ENDPOINT_NUMBER_KEY.name(), serverEndpointNumber, serverCount));
        }
        return serverEndpointNumber;
    }

    /**
     * Reads the preferred server endpoint number stored on the channel of {@code ctx}.
     *
     * @throws NullPointerException if the attribute is not set
     * @throws IndexOutOfBoundsException if the value is not a valid index into the session's
     *         server-side channels
     */
    private int getPreferredServerEndPoint(ChannelHandlerContext ctx) {
        Integer preferredServerEndpoint = (Integer) ctx.channel().attr(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY).get();
        if (preferredServerEndpoint == null) {
            throw new NullPointerException(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name() + " is not set");
        }
        int serverCount = this.getPgsqlSession(ctx).getServerSideChannels().size();
        if (preferredServerEndpoint < 0 || preferredServerEndpoint >= serverCount) {
            throw new IndexOutOfBoundsException(String.format("invalid %s: value: %d, number of server endpoints: %d ", TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name(), preferredServerEndpoint, serverCount));
        }
        return preferredServerEndpoint;
    }

    /**
     * Processes the start-up message received from the frontend.
     *
     * <p>The outcome depends on whether an SSL request was seen and on the configured modes:
     * <ul>
     * <li>If an SSL request was already seen, initialization is complete and the message is
     * forwarded.</li>
     * <li>Else, if frontend SSL is required, the result is an {@code ERROR} ("SSL is
     * required").</li>
     * <li>Else, if backend SSL is required, the result is {@code ORCHESTRATE}.</li>
     * <li>Otherwise the message is forwarded.</li>
     * </ul>
     *
     * <p>The session-initialization request handler is always removed. Unless orchestrating, the
     * response handlers of all server-side channels are removed as well, and their codecs are told
     * to skip their first messages.
     *
     * @param ctx the frontend channel handler context
     * @return the transfer mode, with error details when applicable
     * @throws IOException declared for callers; not thrown directly by this method
     */
    public SessionMessageTransferMode<Void, Void> processStartupMessage(ChannelHandlerContext ctx) throws IOException {
        LOGGER.debug("Start-up message");
        TransferMode transferMode = TransferMode.FORWARD;
        Map<Byte, CString> errorDetails = null;
        if (this.sslRequestReceived.get()) {
            this.logSessionInitializationCompleted();
        } else if (this.sslSessionInitializer.getClientMode() == SSLSessionInitializer.SSLMode.REQUIRED) {
            LOGGER.trace("Reply to the frontend that SSL is required");
            transferMode = TransferMode.ERROR;
            errorDetails = fatalError("SSL is required");
            LOGGER.trace("SSL request is ignored (due to an error on the frontend side)");
        } else if (this.sslSessionInitializer.getServerMode() == SSLSessionInitializer.SSLMode.REQUIRED) {
            LOGGER.trace("Handle SSL initialization with the backend");
            transferMode = TransferMode.ORCHESTRATE;
        } else {
            this.logSessionInitializationCompleted();
        }
        this.removeSessionInitializationRequestHandler(ctx);
        if (transferMode != TransferMode.ORCHESTRATE) {
            this.removeSessionInitializationResponseHandler(ctx, true);
            this.skipSSLResponse(ctx);
        }
        SessionMessageTransferMode<Void, Void> mode = new SessionMessageTransferMode<>(null, transferMode, errorDetails);
        LOGGER.debug("Start-up message processed: transfer mode={}", mode);
        return mode;
    }

    /** Traces the end of session initialization and the SSL state of both sides. */
    private void logSessionInitializationCompleted() {
        LOGGER.trace("Session initialization completed");
        LOGGER.trace("Session {} on the frontend side", this.sessionEncryptedOnFrontendSide.get() ? "encrypted with SSL" : "not encrypted");
        LOGGER.trace("Session {} on the backend side", this.sessionEncryptedOnBackendSide.get() ? "encrypted with SSL" : "not encrypted");
    }

    /**
     * Builds error details made of a FATAL severity and the given message, in that insertion
     * order.
     */
    private static Map<Byte, CString> fatalError(String message) {
        Map<Byte, CString> details = new LinkedHashMap<>();
        details.put(ERROR_FIELD_SEVERITY, CString.valueOf((CharSequence) "FATAL"));
        details.put(ERROR_FIELD_MESSAGE, CString.valueOf((CharSequence) message));
        return details;
    }

    /**
     * Blocks until {@link #processSSLResponse} has been called at least once per server-side
     * channel of the session.
     *
     * @param ctx any channel handler context carrying the session attribute
     * @throws IOException if the waiting thread is interrupted (its interrupt status is restored)
     */
    public void waitForResponses(ChannelHandlerContext ctx) throws IOException {
        PgsqlSession pgsqlSession = this.getPgsqlSession(ctx);
        synchronized (this) {
            // Loop guards against spurious wake-ups and wake-ups for other responses.
            while (this.sslResponseReceived.get() < pgsqlSession.getServerSideChannels().size()) {
                try {
                    this.wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IOException(e);
                }
            }
        }
    }

    /** Installs an SSL handler on the session's client-side channel pipeline. */
    private void addSSLHandlerOnFrontendSide(ChannelHandlerContext ctx) throws IOException {
        Future<?> handshakeFuture = this.sslSessionInitializer.addSSLHandlerOnClientSide(ctx, this.getPgsqlSession(ctx).getClientSideChannel().pipeline());
        this.trackHandshake(handshakeFuture, this.sessionEncryptedOnFrontendSide, "frontend");
    }

    /** Installs an SSL handler on the server side for {@code ctx}. */
    private void addSSLHandlerOnBackendSide(ChannelHandlerContext ctx) throws SSLException {
        Future<?> handshakeFuture = this.sslSessionInitializer.addSSLHandlerOnServerSide(ctx);
        this.trackHandshake(handshakeFuture, this.sessionEncryptedOnBackendSide, "backend");
    }

    /**
     * Records the outcome of an SSL handshake.
     *
     * <p>{@code encryptedFlag} is set only when the handshake future completes successfully. A
     * failed handshake is logged with its cause instead.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void trackHandshake(Future<?> handshakeFuture, final AtomicBoolean encryptedFlag, final String side) {
        handshakeFuture.addListener((GenericFutureListener) new GenericFutureListener<Future<? super Channel>>() {
            @Override
            public void operationComplete(Future<? super Channel> future) throws Exception {
                if (future.isSuccess()) {
                    encryptedFlag.set(true);
                    LOGGER.trace("SSL handshake for {} side completed", side);
                } else {
                    LOGGER.warn("SSL handshake for {} side failed", side, future.cause());
                }
            }
        });
    }

    /** Removes the "SessionInitializationRequestHandler" from the pipeline of {@code ctx}, if present. */
    private void removeSessionInitializationRequestHandler(ChannelHandlerContext ctx) {
        ChannelPipeline pipeline = ctx.pipeline();
        ChannelHandler handler = pipeline.get("SessionInitializationRequestHandler");
        if (handler != null) {
            pipeline.remove(handler);
        }
    }

    /**
     * Removes the "SessionInitializationResponseHandler" from server-side pipelines, where present.
     *
     * @param all {@code true} for every server-side channel of the session; {@code false} for only
     *        the server endpoint attached to {@code ctx}
     */
    private void removeSessionInitializationResponseHandler(ChannelHandlerContext ctx, boolean all) {
        List<Channel> serverSideChannels;
        if (all) {
            serverSideChannels = this.getPgsqlSession(ctx).getServerSideChannels();
        } else {
            Channel serverSideChannel = this.getPgsqlSession(ctx).getServerSideChannel(this.getServerEndPoint(ctx));
            serverSideChannels = Collections.singletonList(serverSideChannel);
        }
        for (Channel serverSideChannel : serverSideChannels) {
            ChannelPipeline pipeline = serverSideChannel.pipeline();
            ChannelHandler handler = pipeline.get("SessionInitializationResponseHandler");
            if (handler != null) {
                pipeline.remove(handler);
            }
        }
    }

    /**
     * Calls {@code skipFirstMessages()} on the "PgsqlPartCodec" of every server-side channel.
     * Presumably this makes the codecs drop a pending SSL response - confirm against
     * {@link PgsqlRawPartCodec}.
     */
    private void skipSSLResponse(ChannelHandlerContext ctx) {
        for (Channel serverSideChannel : this.getPgsqlSession(ctx).getServerSideChannels()) {
            PgsqlRawPartCodec codec = (PgsqlRawPartCodec) serverSideChannel.pipeline().get("PgsqlPartCodec");
            codec.skipFirstMessages();
        }
    }

    /**
     * Processes a cancel request: it is always forwarded.
     *
     * @param ctx the frontend channel handler context (unused)
     * @param code the cancel request code (only logged)
     * @param processId the target process ID (only logged)
     * @param secretKey the cancel key (only logged)
     * @return a {@code FORWARD} transfer mode without error details
     * @throws IOException declared for callers; not thrown directly by this method
     */
    public SessionMessageTransferMode<Void, Void> processCancelRequest(ChannelHandlerContext ctx, int code, int processId, int secretKey) throws IOException {
        LOGGER.debug("Cancel request: code={}, process ID={}, secret key={}", new Object[]{code, processId, secretKey});
        LOGGER.trace("Forward the cancel request");
        Map<Byte, CString> noErrorDetails = null;
        SessionMessageTransferMode<Void, Void> mode = new SessionMessageTransferMode<>(null, TransferMode.FORWARD, noErrorDetails);
        LOGGER.debug("Cancel request processed: transfer mode={}", mode);
        return mode;
    }

    /** Returns the {@link PgsqlSession} stored on the channel of {@code ctx} (may be {@code null}). */
    private PgsqlSession getPgsqlSession(ChannelHandlerContext ctx) {
        return (PgsqlSession) ((Object) ctx.channel().attr(PgsqlConstants.SESSION_KEY).get());
    }
}

