/*
 * Decompiled with CFR 0.152.
 */
package org.apache.avro.ipc;

import java.io.EOFException;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.avro.Protocol;
import org.apache.avro.ipc.Transceiver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Transceiver} that performs a SASL negotiation over a {@link SocketChannel} before
 * exchanging Avro RPC data, and passes data frames through the negotiated SASL security layer
 * when the negotiated quality of protection is anything other than {@code "auth"}.
 *
 * <p>Wire format produced and consumed by this class:
 * <ul>
 *   <li>Negotiation messages: a one-byte {@link Status} code followed by length-prefixed frames
 *       (a 4-byte length written via {@link ByteBuffer#putInt}, then that many bytes). A
 *       {@code START} message carries a mechanism-name frame followed by an initial-response
 *       frame; the other statuses carry a single frame.</li>
 *   <li>Data: a sequence of length-prefixed frames terminated by a zero-length frame. The
 *       terminator itself is never wrapped.</li>
 * </ul>
 *
 * <p>{@link #transceive}, {@link #readBuffers} and {@link #writeBuffers} are
 * {@code synchronized} on this instance.
 */
public class SaslSocketTransceiver extends Transceiver {
    private static final Logger LOG = LoggerFactory.getLogger(SaslSocketTransceiver.class);

    /** Shared zero-length buffer; it never has bytes remaining, so writing it never mutates it. */
    private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

    /** Consecutive unwrapped data buffers are coalesced into a single frame up to this size. */
    private static final int MAX_COALESCED_FRAME_LENGTH = 8192;

    /** SASL negotiated-property name for the quality of protection. */
    private static final String QOP_PROPERTY = "javax.security.sasl.qop";

    private final SaslParticipant sasl;
    private final SocketChannel channel;
    /** True when data frames must go through SASL wrap/unwrap (QOP other than "auth"). */
    private boolean dataIsWrapped;
    /**
     * Set on the client when its mechanism completed immediately after sending START; the
     * server's final status message is then consumed at the start of the first transceive.
     */
    private boolean saslResponsePiggybacked;
    private Protocol remote;
    private final ByteBuffer readHeader = ByteBuffer.allocate(4);
    private final ByteBuffer writeHeader = ByteBuffer.allocate(4);
    /** Zero-length frame header appended after every batch of data frames. */
    private final ByteBuffer zeroHeader = ByteBuffer.allocate(4).putInt(0);

    /**
     * Connects to {@code address} and negotiates with the built-in ANONYMOUS mechanism.
     *
     * @param address the server address to connect to
     * @throws IOException if connecting or negotiating fails
     */
    public SaslSocketTransceiver(SocketAddress address) throws IOException {
        this(address, new AnonymousClient());
    }

    /**
     * Connects to {@code address} and negotiates as a SASL client. If anything after the
     * connection is established fails, the socket channel is closed before the error propagates.
     *
     * @param address the server address to connect to
     * @param saslClient the client-side SASL mechanism; disposed by {@link #close()}
     * @throws IOException if connecting or negotiating fails
     */
    public SaslSocketTransceiver(SocketAddress address, SaslClient saslClient) throws IOException {
        this.sasl = new SaslParticipant(saslClient);
        this.channel = SocketChannel.open(address);
        boolean negotiated = false;
        try {
            this.channel.socket().setTcpNoDelay(true);
            LOG.debug("open to {}", getRemoteName());
            open(true);
            negotiated = true;
        } finally {
            if (!negotiated) {
                // The instance never escapes, so nobody else could close this channel.
                closeChannelQuietly();
            }
        }
    }

    /**
     * Negotiates as a SASL server over an already-connected channel. This constructor does not
     * close {@code channel} if negotiation fails.
     *
     * @param channel the accepted connection
     * @param saslServer the server-side SASL mechanism; disposed by {@link #close()}
     * @throws IOException if negotiating fails
     */
    public SaslSocketTransceiver(SocketChannel channel, SaslServer saslServer) throws IOException {
        this.sasl = new SaslParticipant(saslServer);
        this.channel = channel;
        LOG.debug("open from {}", getRemoteName());
        open(false);
    }

    @Override
    public boolean isConnected() {
        return remote != null;
    }

    @Override
    public void setRemote(Protocol remote) {
        this.remote = remote;
    }

    @Override
    public Protocol getRemote() {
        return remote;
    }

    /**
     * Returns the remote socket address as a string, or {@code "null"} if the socket reports no
     * remote address (instead of throwing a {@link NullPointerException}).
     */
    @Override
    public String getRemoteName() {
        return String.valueOf(channel.socket().getRemoteSocketAddress());
    }

    /**
     * Sends {@code request} and returns the response. On the first call after a piggybacked
     * negotiation, first consumes the server's final SASL status and fails unless it is COMPLETE.
     */
    @Override
    public synchronized List<ByteBuffer> transceive(List<ByteBuffer> request) throws IOException {
        if (saslResponsePiggybacked) {
            saslResponsePiggybacked = false;
            Status status = readStatus();
            ByteBuffer frame = readFrame();
            switch (status) {
                case COMPLETE:
                    break;
                case FAIL:
                    throw new SaslException("Fail: " + decodeUtf8(frame));
                default:
                    throw new IOException("Unexpected SASL status: " + status);
            }
        }
        return super.transceive(request);
    }

    /**
     * Runs the SASL negotiation loop until the local mechanism reports completion, then records
     * whether data frames must be wrapped based on the negotiated QOP.
     *
     * @param isClient true to initiate with a START message, false to wait for the peer's START
     */
    private void open(boolean isClient) throws IOException {
        LOG.debug("beginning SASL negotiation");
        if (isClient) {
            sendStart();
        }
        while (!sasl.isComplete()) {
            Status status = readStatus();
            ByteBuffer frame = readFrame();
            switch (status) {
                case START:
                    respond(readStartResponse(frame));
                    break;
                case CONTINUE:
                    respond(frame);
                    break;
                case COMPLETE:
                    sasl.evaluate(frame.array());
                    if (!sasl.isComplete()) {
                        throw new SaslException("Expected completion!");
                    }
                    break;
                case FAIL:
                    throw new SaslException("Fail: " + decodeUtf8(frame));
                default:
                    throw new IOException("Unexpected SASL status: " + status);
            }
        }
        LOG.debug("SASL opened");
        String qop = (String) sasl.getNegotiatedProperty(QOP_PROPERTY);
        LOG.debug("QOP = {}", qop);
        dataIsWrapped = qop != null && !qop.equalsIgnoreCase("auth");
    }

    /**
     * Client side: sends START with the mechanism name and the optional initial response, and
     * remembers whether the server's reply must be read later (mechanism already complete).
     */
    private void sendStart() throws IOException {
        ByteBuffer response = EMPTY;
        if (sasl.hasInitialResponse()) {
            byte[] initial = sasl.evaluate(new byte[0]);
            // evaluateChallenge may return null; send an empty frame in that case, as respond() does.
            if (initial != null) {
                response = ByteBuffer.wrap(initial);
            }
        }
        write(Status.START, sasl.getMechanismName(), response);
        if (sasl.isComplete()) {
            saslResponsePiggybacked = true;
        }
    }

    /**
     * Reads the rest of a START message. {@code mechanismFrame} holds the mechanism name, and
     * the initial-response frame that follows it is read and returned. On a mechanism mismatch,
     * FAIL is sent to the peer and negotiation aborts.
     */
    private ByteBuffer readStartResponse(ByteBuffer mechanismFrame) throws IOException {
        String mechanism = decodeUtf8(mechanismFrame);
        ByteBuffer response = readFrame();
        if (!mechanism.equalsIgnoreCase(sasl.getMechanismName())) {
            write(Status.FAIL, "Wrong mechanism: " + mechanism);
            throw new SaslException("Wrong mechanism: " + mechanism);
        }
        return response;
    }

    /**
     * Evaluates the peer's token and replies with CONTINUE or COMPLETE plus our token. If the
     * local mechanism rejects the token, the error is reported to the peer with FAIL and then
     * rethrown. This class aborts on receiving FAIL, so waiting for another peer message would
     * only end in EOF or a hang.
     */
    private void respond(ByteBuffer frame) throws IOException {
        byte[] response;
        Status status;
        SaslException failure = null;
        try {
            response = sasl.evaluate(frame.array());
            status = sasl.isComplete() ? Status.COMPLETE : Status.CONTINUE;
        } catch (SaslException e) {
            response = e.toString().getBytes("UTF-8");
            status = Status.FAIL;
            failure = e;
        }
        write(status, response != null ? ByteBuffer.wrap(response) : EMPTY);
        if (failure != null) {
            throw failure;
        }
    }

    /** Decodes the remaining bytes of a heap buffer as UTF-8 without changing its position. */
    private String decodeUtf8(ByteBuffer buffer) throws IOException {
        return new String(buffer.array(), buffer.arrayOffset() + buffer.position(),
                buffer.remaining(), "UTF-8");
    }

    /** Reads data frames (unwrapping them if needed) until the zero-length terminator frame. */
    @Override
    public synchronized List<ByteBuffer> readBuffers() throws IOException {
        List<ByteBuffer> buffers = new ArrayList<ByteBuffer>();
        for (ByteBuffer buffer = readFrameAndUnwrap(); buffer.hasRemaining();
                buffer = readFrameAndUnwrap()) {
            buffers.add(buffer);
        }
        return buffers;
    }

    /**
     * Reads a single status byte and maps it to a {@link Status}.
     *
     * @throws IOException if the byte is not a known status code
     */
    private Status readStatus() throws IOException {
        ByteBuffer buffer = ByteBuffer.allocate(1);
        read(buffer);
        return Status.fromCode(buffer.get());
    }

    /**
     * Reads one data frame and unwraps it when a security layer is active. The zero-length
     * terminator is returned as-is, because {@link #writeBuffers} never wraps it.
     */
    private ByteBuffer readFrameAndUnwrap() throws IOException {
        ByteBuffer frame = readFrame();
        if (!dataIsWrapped || !frame.hasRemaining()) {
            return frame;
        }
        ByteBuffer unwrapped = ByteBuffer.wrap(sasl.unwrap(frame.array()));
        LOG.debug("unwrapped data of length: {}", unwrapped.remaining());
        return unwrapped;
    }

    /**
     * Reads a 4-byte length header and then exactly that many bytes.
     *
     * @return a buffer positioned at 0 holding the frame bytes
     * @throws IOException if the peer sends a negative length or the stream ends early
     */
    private ByteBuffer readFrame() throws IOException {
        read(readHeader);
        int length = readHeader.getInt();
        if (length < 0) {
            throw new IOException("Invalid SASL frame length: " + length);
        }
        ByteBuffer buffer = ByteBuffer.allocate(length);
        LOG.debug("about to read: {} bytes", length);
        read(buffer);
        return buffer;
    }

    /** Fills {@code buffer} completely from the channel, then flips it for reading. */
    private void read(ByteBuffer buffer) throws IOException {
        buffer.clear();
        while (buffer.hasRemaining()) {
            if (channel.read(buffer) == -1) {
                throw new EOFException();
            }
        }
        buffer.flip();
    }

    /**
     * Writes {@code buffers} as length-prefixed frames followed by a zero-length terminator.
     * Without a security layer, consecutive buffers are coalesced into frames of up to
     * {@value #MAX_COALESCED_FRAME_LENGTH} bytes. With one, each buffer is wrapped into its own
     * frame. Empty buffers are skipped; a {@code null} list writes nothing.
     */
    @Override
    public synchronized void writeBuffers(List<ByteBuffer> buffers) throws IOException {
        if (buffers == null) {
            return;
        }
        List<ByteBuffer> writes = new ArrayList<ByteBuffer>(buffers.size() * 2 + 1);
        int currentLength = 0;
        ByteBuffer currentHeader = writeHeader;
        for (ByteBuffer original : buffers) {
            if (!original.hasRemaining()) {
                continue;
            }
            ByteBuffer buffer = dataIsWrapped ? wrapOutgoing(original) : original;
            int length = buffer.remaining();
            // Written as a subtraction so that very large lengths cannot overflow the sum.
            if (!dataIsWrapped && length <= MAX_COALESCED_FRAME_LENGTH - currentLength) {
                if (currentLength == 0) {
                    writes.add(currentHeader);
                }
                currentLength += length;
                // Grow the header of the frame currently being coalesced (already in writes).
                currentHeader.clear();
                currentHeader.putInt(currentLength);
                LOG.debug("adding {} to write, total now {}", length, currentLength);
            } else {
                currentLength = length;
                currentHeader = ByteBuffer.allocate(4).putInt(length);
                writes.add(currentHeader);
                LOG.debug("planning write of {}", length);
            }
            currentHeader.flip();
            writes.add(buffer);
        }
        // clear() (not flip()) so the full 4 zero bytes are sent even after an interrupted write.
        zeroHeader.clear();
        writes.add(zeroHeader);
        writeFully(writes.toArray(new ByteBuffer[writes.size()]));
    }

    /**
     * Wraps the remaining bytes of {@code buffer} with the SASL security layer, honouring
     * {@code arrayOffset()} for sliced heap buffers. Direct and read-only buffers are copied
     * first. The source buffer's position is not changed.
     */
    private ByteBuffer wrapOutgoing(ByteBuffer buffer) throws SaslException {
        LOG.debug("wrapping data of length: {}", buffer.remaining());
        byte[] wrapped;
        if (buffer.hasArray()) {
            wrapped = sasl.wrap(buffer.array(), buffer.arrayOffset() + buffer.position(),
                    buffer.remaining());
        } else {
            byte[] copy = new byte[buffer.remaining()];
            buffer.duplicate().get(copy);
            wrapped = sasl.wrap(copy, 0, copy.length);
        }
        return ByteBuffer.wrap(wrapped);
    }

    /** Writes a START-style message: status, a {@code prefix} frame, then a {@code response} frame. */
    private void write(Status status, String prefix, ByteBuffer response) throws IOException {
        LOG.debug("write status: {} {}", status, prefix);
        write(status, prefix);
        write(response);
    }

    /** Writes a status byte followed by {@code response} encoded as a UTF-8 frame. */
    private void write(Status status, String response) throws IOException {
        write(status, ByteBuffer.wrap(response.getBytes("UTF-8")));
    }

    /** Writes a status byte followed by one length-prefixed frame. */
    private void write(Status status, ByteBuffer response) throws IOException {
        LOG.debug("write status: {}", status);
        writeFully(ByteBuffer.wrap(new byte[] {status.code}));
        write(response);
    }

    /** Writes the remaining bytes of {@code response} as one length-prefixed frame. */
    private void write(ByteBuffer response) throws IOException {
        LOG.debug("writing: {}", response.remaining());
        writeHeader.clear();
        writeHeader.putInt(response.remaining()).flip();
        writeFully(writeHeader, response);
    }

    /** Performs gathering writes until every buffer has been fully drained. */
    private void writeFully(ByteBuffer... buffers) throws IOException {
        int start = 0;
        while (start < buffers.length) {
            channel.write(buffers, start, buffers.length - start);
            // Skip past buffers that are now fully written; resume from the first partial one.
            while (start < buffers.length && !buffers[start].hasRemaining()) {
                start++;
            }
        }
    }

    /** Closes the channel after a failed negotiation without masking the original error. */
    private void closeChannelQuietly() {
        try {
            channel.close();
        } catch (IOException e) {
            LOG.debug("failed to close channel after SASL negotiation error", e);
        }
    }

    /**
     * Closes the socket channel if it is still open and always disposes the SASL mechanism, even
     * when closing the channel fails.
     */
    @Override
    public void close() throws IOException {
        try {
            if (channel.isOpen()) {
                LOG.info("closing to {}", getRemoteName());
                channel.close();
            }
        } finally {
            sasl.dispose();
        }
    }

    /**
     * Minimal ANONYMOUS mechanism: completes immediately after sending the local
     * {@code user.name} system property (or an empty string if it is unset) as its
     * initial response. No security layer is negotiated.
     */
    private static class AnonymousClient implements SaslClient {
        private AnonymousClient() {
        }

        @Override
        public String getMechanismName() {
            return "ANONYMOUS";
        }

        @Override
        public boolean hasInitialResponse() {
            return true;
        }

        @Override
        public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
            try {
                return System.getProperty("user.name", "").getBytes("UTF-8");
            } catch (UnsupportedEncodingException e) {
                throw new SaslException(e.toString(), e);
            }
        }

        @Override
        public boolean isComplete() {
            return true;
        }

        @Override
        public byte[] unwrap(byte[] incoming, int offset, int len) {
            throw new UnsupportedOperationException();
        }

        @Override
        public byte[] wrap(byte[] outgoing, int offset, int len) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Object getNegotiatedProperty(String propName) {
            return null;
        }

        @Override
        public void dispose() {
        }
    }

    /** Presents a {@link SaslClient} or a {@link SaslServer} through one common interface. */
    private static class SaslParticipant {
        /** Non-null when acting as a server; exactly one of server/client is set. */
        private final SaslServer server;
        /** Non-null when acting as a client. */
        private final SaslClient client;

        SaslParticipant(SaslServer server) {
            this.server = server;
            this.client = null;
        }

        SaslParticipant(SaslClient client) {
            this.server = null;
            this.client = client;
        }

        /** True only for a client mechanism that sends an initial response with START. */
        boolean hasInitialResponse() {
            return client != null && client.hasInitialResponse();
        }

        String getMechanismName() {
            return client != null ? client.getMechanismName() : server.getMechanismName();
        }

        boolean isComplete() {
            return client != null ? client.isComplete() : server.isComplete();
        }

        void dispose() throws SaslException {
            if (client != null) {
                client.dispose();
            } else {
                server.dispose();
            }
        }

        byte[] unwrap(byte[] buf) throws SaslException {
            return client != null
                    ? client.unwrap(buf, 0, buf.length)
                    : server.unwrap(buf, 0, buf.length);
        }

        byte[] wrap(byte[] buf, int start, int len) throws SaslException {
            return client != null ? client.wrap(buf, start, len) : server.wrap(buf, start, len);
        }

        Object getNegotiatedProperty(String propName) {
            return client != null
                    ? client.getNegotiatedProperty(propName)
                    : server.getNegotiatedProperty(propName);
        }

        /** Client: evaluates a server challenge. Server: evaluates a client response. */
        byte[] evaluate(byte[] buf) throws SaslException {
            return client != null ? client.evaluateChallenge(buf) : server.evaluateResponse(buf);
        }
    }

    /**
     * Negotiation status codes. The explicit byte values are the wire format and match the
     * original declaration order (START=0 … COMPLETE=3); do not change them.
     */
    private enum Status {
        START(0),
        CONTINUE(1),
        FAIL(2),
        COMPLETE(3);

        private final byte code;

        Status(int code) {
            this.code = (byte) code;
        }

        /**
         * Maps a wire byte back to its status.
         *
         * @throws IOException if {@code code} is not one of the known status bytes
         */
        static Status fromCode(byte code) throws IOException {
            for (Status status : values()) {
                if (status.code == code) {
                    return status;
                }
            }
            throw new IOException("Unexpected SASL status byte: " + code);
        }
    }
}

