/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.regex.Pattern;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.meta.DataCharacteristics;

public class InitFEDInstruction
extends FEDInstruction
implements LineageTraceable {
    private static final Log LOG = LogFactory.getLog((String)InitFEDInstruction.class.getName());
    public static final String FED_MATRIX_IDENTIFIER = "matrix";
    public static final String FED_FRAME_IDENTIFIER = "frame";
    private CPOperand _type;
    private CPOperand _addresses;
    private CPOperand _ranges;
    private CPOperand _output;

    public InitFEDInstruction(CPOperand type, CPOperand addresses, CPOperand ranges, CPOperand out, String opcode, String instr) {
        super(FEDInstruction.FEDType.Init, opcode, instr);
        this._type = type;
        this._addresses = addresses;
        this._ranges = ranges;
        this._output = out;
    }

    public static InitFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        if (parts.length != 5) {
            throw new DMLRuntimeException("Invalid number of operands in federated instruction: " + str);
        }
        String opcode = parts[0];
        CPOperand type = new CPOperand(parts[1]);
        CPOperand addresses = new CPOperand(parts[2]);
        CPOperand ranges = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        return new InitFEDInstruction(type, addresses, ranges, out, opcode, str);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        Types.DataType fedDataType;
        String type = ec.getScalarInput(this._type).getStringValue();
        ListObject addresses = ec.getListObject(this._addresses.getName());
        ListObject ranges = ec.getListObject(this._ranges.getName());
        ArrayList<Pair<FederatedRange, FederatedData>> feds = new ArrayList<Pair<FederatedRange, FederatedData>>();
        if (addresses.getLength() * 2 != ranges.getLength()) {
            throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges (begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
        }
        if (type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
            fedDataType = Types.DataType.MATRIX;
        } else if (type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
            fedDataType = Types.DataType.FRAME;
        } else {
            throw new DMLRuntimeException("type \"" + type + "\" non valid federated type");
        }
        long[] usedDims = new long[]{0L, 0L};
        for (int i = 0; i < addresses.getLength(); ++i) {
            Data addressData = addresses.getData().get(i);
            if (addressData instanceof StringObject) {
                String[] parsedValues = InitFEDInstruction.parseURL(((StringObject)addressData).getStringValue());
                String host = parsedValues[0];
                int port = Integer.parseInt(parsedValues[1]);
                String filePath = parsedValues[2];
                List<Data> rangesData = ranges.getData();
                Data beginData = rangesData.get(i * 2);
                Data endData = rangesData.get(i * 2 + 1);
                if (beginData.getDataType() != Types.DataType.LIST || endData.getDataType() != Types.DataType.LIST) {
                    throw new DMLRuntimeException("Federated read ranges (lower, upper) have to be lists of dimensions");
                }
                List<Data> beginDimsData = ((ListObject)beginData).getData();
                List<Data> endDimsData = ((ListObject)endData).getData();
                long[] beginDims = new long[beginDimsData.size()];
                long[] endDims = new long[beginDims.length];
                for (int d = 0; d < beginDims.length; ++d) {
                    beginDims[d] = ((ScalarObject)beginDimsData.get(d)).getLongValue();
                    endDims[d] = ((ScalarObject)endDimsData.get(d)).getLongValue();
                }
                usedDims[0] = Math.max(usedDims[0], endDims[0]);
                usedDims[1] = Math.max(usedDims[1], endDims[1]);
                try {
                    FederatedData federatedData = new FederatedData(fedDataType, new InetSocketAddress(InetAddress.getByName(host), port), filePath);
                    feds.add((Pair<FederatedRange, FederatedData>)new ImmutablePair((Object)new FederatedRange(beginDims, endDims), (Object)federatedData));
                    continue;
                }
                catch (UnknownHostException e) {
                    throw new DMLRuntimeException("federated host was unknown: " + host);
                }
            }
            throw new DMLRuntimeException("federated instruction only takes strings as addresses");
        }
        if (type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
            CacheableData<?> output = ec.getCacheableData(this._output);
            output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
            InitFEDInstruction.federateMatrix(output, feds);
        } else if (type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
            if (usedDims[1] > Integer.MAX_VALUE) {
                throw new DMLRuntimeException("federated Frame can not have more than max int columns, because the schema can only be max int length");
            }
            FrameObject output = ec.getFrameObject(this._output);
            output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
            InitFEDInstruction.federateFrame(output, feds);
        } else {
            throw new DMLRuntimeException("type \"" + type + "\" non valid federated type");
        }
    }

    public static String[] parseURL(String input) {
        try {
            String filePath;
            URL address = new URL("http://" + input);
            String host = address.getHost();
            if (host.length() == 0) {
                throw new IllegalArgumentException("Missing Host name for federated address");
            }
            String ipRegex = "^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$";
            if (host.matches("^\\d+\\.\\d+\\.\\d+\\.\\d+$") && !host.matches(ipRegex)) {
                throw new IllegalArgumentException("Input Host address looks like an IP address but is outside range");
            }
            int port = address.getPort();
            if (port == -1) {
                port = 4040;
            }
            if ((filePath = address.getPath()).length() <= 1) {
                throw new IllegalArgumentException("Missing File path for federated address");
            }
            filePath = filePath.substring(1);
            if (address.getQuery() != null) {
                throw new IllegalArgumentException("Query is not supported");
            }
            if (address.getRef() != null) {
                throw new IllegalArgumentException("Reference is not supported");
            }
            return new String[]{host, String.valueOf(port), filePath};
        }
        catch (MalformedURLException e) {
            throw new IllegalArgumentException("federated address `" + input + "` does not fit required URL pattern of \"host:port/directory\"", e);
        }
    }

    public static void federateMatrix(CacheableData<?> output, List<Pair<FederatedRange, FederatedData>> workers) {
        ArrayList<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<Pair<FederatedRange, FederatedData>>();
        for (Pair<FederatedRange, FederatedData> e : workers) {
            fedMapping.add(e);
        }
        ArrayList<ImmutablePair> idResponses = new ArrayList<ImmutablePair>();
        long id = FederationUtils.getNextFedDataID();
        boolean rowPartitioned = true;
        boolean colPartitioned = true;
        for (Pair pair : fedMapping) {
            FederatedRange federatedRange = (FederatedRange)pair.getKey();
            FederatedData value = (FederatedData)pair.getValue();
            if (!value.isInitialized()) {
                long[] beginDims = federatedRange.getBeginDims();
                long[] endDims = federatedRange.getEndDims();
                long[] dims = output.getDataCharacteristics().getDims();
                for (int i = 0; i < dims.length; ++i) {
                    dims[i] = endDims[i] - beginDims[i];
                }
                idResponses.add(new ImmutablePair((Object)value, value.initFederatedData(id)));
            }
            rowPartitioned &= federatedRange.getSize(1) == output.getNumColumns();
            colPartitioned &= federatedRange.getSize(0) == output.getNumRows();
        }
        try {
            int timeout = ConfigurationManager.getDMLConfig().getIntValue("sysds.federated.initialization.timeout");
            if (LOG.isDebugEnabled()) {
                LOG.debug((Object)("Federated Initialization with timeout: " + timeout));
            }
            for (Pair pair : idResponses) {
                FederatedResponse re = (FederatedResponse)((Future)pair.getRight()).get(timeout, TimeUnit.SECONDS);
                DataCharacteristics dc = (DataCharacteristics)re.getData()[1];
                if (dc.getRows() <= output.getNumRows() && dc.getCols() <= output.getNumColumns()) continue;
                throw new DMLRuntimeException("Invalid federated meta data: " + output.getDataCharacteristics() + " vs federated response: " + dc);
            }
        }
        catch (TimeoutException e) {
            throw new DMLRuntimeException("Federated Initialization timeout exceeded", e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Federation initialization failed", e);
        }
        output.getDataCharacteristics().setNonZeros(-1L);
        output.getDataCharacteristics().setBlocksize(ConfigurationManager.getBlocksize());
        output.setFedMapping(new FederationMap(id, fedMapping));
        output.getFedMapping().setType(rowPartitioned && colPartitioned ? FederationMap.FType.FULL : (rowPartitioned ? FederationMap.FType.ROW : (colPartitioned ? FederationMap.FType.COL : FederationMap.FType.OTHER)));
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)("Fed map Inited:" + output.getFedMapping()));
        }
    }

    public static void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
        ArrayList<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<Pair<FederatedRange, FederatedData>>();
        for (Pair<FederatedRange, FederatedData> e : workers) {
            fedMapping.add(e);
        }
        ArrayList<ImmutablePair> idResponses = new ArrayList<ImmutablePair>();
        long id = FederationUtils.getNextFedDataID();
        boolean rowPartitioned = true;
        boolean colPartitioned = true;
        for (Pair pair : fedMapping) {
            FederatedRange federatedRange = (FederatedRange)pair.getKey();
            FederatedData value = (FederatedData)pair.getValue();
            if (!value.isInitialized()) {
                long[] beginDims = federatedRange.getBeginDims();
                long[] endDims = federatedRange.getEndDims();
                long[] dims = output.getDataCharacteristics().getDims();
                for (int i = 0; i < dims.length; ++i) {
                    dims[i] = endDims[i] - beginDims[i];
                }
                idResponses.add(new ImmutablePair((Object)value, (Object)new ImmutablePair((Object)((int)beginDims[1]), value.initFederatedData(id))));
            }
            rowPartitioned &= federatedRange.getSize(1) == output.getNumColumns();
            colPartitioned &= federatedRange.getSize(0) == output.getNumRows();
        }
        Types.ValueType[] schema = new Types.ValueType[(int)output.getNumColumns()];
        Arrays.fill((Object[])schema, (Object)Types.ValueType.UNKNOWN);
        try {
            for (Pair pair : idResponses) {
                FederatedData fedData = (FederatedData)pair.getLeft();
                FederatedResponse response = (FederatedResponse)((Future)((Pair)pair.getRight()).getRight()).get();
                int startCol = (Integer)((Pair)pair.getRight()).getLeft();
                InitFEDInstruction.handleFedFrameResponse(schema, fedData, response, startCol);
                DataCharacteristics dc = (DataCharacteristics)response.getData()[2];
                if (dc.getRows() <= output.getNumRows() && dc.getCols() <= output.getNumColumns()) continue;
                throw new DMLRuntimeException("Invalid federated meta data: " + output.getDataCharacteristics() + " vs federated response: " + dc);
            }
        }
        catch (Exception exception) {
            throw new DMLRuntimeException("Federation initialization failed", exception);
        }
        output.getDataCharacteristics().setNonZeros(output.getNumColumns() * output.getNumRows());
        output.setSchema(schema);
        output.setFedMapping(new FederationMap(id, fedMapping));
        output.getFedMapping().setType(rowPartitioned && colPartitioned ? FederationMap.FType.FULL : (rowPartitioned ? FederationMap.FType.ROW : (colPartitioned ? FederationMap.FType.COL : FederationMap.FType.OTHER)));
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)("Fed map Inited: " + output.getFedMapping()));
        }
    }

    private static void handleFedFrameResponse(Types.ValueType[] schema, FederatedData federatedData, FederatedResponse response, int startColumn) {
        try {
            Object[] data = response.getData();
            federatedData.setVarID((Long)data[0]);
            Types.ValueType[] range_schema = (Types.ValueType[])data[1];
            for (int i = 0; i < range_schema.length; ++i) {
                Types.ValueType vType = range_schema[i];
                int schema_index = startColumn + i;
                if (schema[schema_index] != Types.ValueType.UNKNOWN && schema[schema_index] != vType) {
                    throw new DMLRuntimeException("federated Frame schemas mismatch");
                }
                schema[schema_index] = vType;
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Exception in frame response from federated worker.", e);
        }
    }

    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        String type = ec.getScalarInput(this._type).getStringValue();
        ListObject addresses = ec.getListObject(this._addresses.getName());
        ListObject ranges = ec.getListObject(this._ranges.getName());
        LineageItem[] liInputs = new LineageItem[addresses.getLength()];
        for (int i = 0; i < addresses.getLength(); ++i) {
            Data addressData = addresses.getData().get(i);
            if (!(addressData instanceof StringObject)) {
                throw new DMLRuntimeException("federated instruction only takes strings as addresses");
            }
            String address = ((StringObject)addressData).getStringValue();
            List<Data> rangesData = ranges.getData();
            List<Data> beginDimsData = ((ListObject)rangesData.get(i * 2)).getData();
            List<Data> endDimsData = ((ListObject)rangesData.get(i * 2 + 1)).getData();
            String rl = ((ScalarObject)beginDimsData.get(0)).getStringValue();
            String cl = ((ScalarObject)beginDimsData.get(1)).getStringValue();
            String ru = ((ScalarObject)endDimsData.get(0)).getStringValue();
            String cu = ((ScalarObject)endDimsData.get(1)).getStringValue();
            String data = InstructionUtils.concatOperands(type, address, rl, cl, ru, cu);
            liInputs[i] = new LineageItem(data);
        }
        return Pair.of((Object)this._output.getName(), (Object)new LineageItem(this.getOpcode(), liInputs));
    }
}

