/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.library.match;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.iotdb.library.match.model.DTWMatchResult;
import org.apache.iotdb.library.match.model.DTWState;
import org.apache.iotdb.udf.api.State;
import org.apache.iotdb.udf.api.UDAF;
import org.apache.iotdb.udf.api.customizer.config.UDAFConfigurations;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
import org.apache.iotdb.udf.api.type.Type;
import org.apache.iotdb.udf.api.utils.ResultValue;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BitMap;

/**
 * Aggregate UDF that reports the windows of an input series whose dynamic time warping (DTW)
 * distance to a user-supplied pattern does not exceed a threshold.
 *
 * <p>Required attributes:
 *
 * <ul>
 *   <li>{@code pattern} - comma-separated numbers forming the reference series;
 *   <li>{@code threshold} - largest DTW distance that still counts as a match.
 * </ul>
 *
 * <p>The result is a TEXT value holding the string form of the collected {@link DTWMatchResult}
 * list, or null when nothing matched.
 *
 * <p>NOTE(review): how samples are buffered is decided by {@link DTWState#updateBuffer}, which is
 * not visible here; presumably it keeps a sliding window of {@code pattern.length} samples -
 * confirm against {@code DTWState}.
 */
public class UDAFDTWMatch
implements UDAF {
    private static final String PATTERN_ATTRIBUTE = "pattern";
    private static final String THRESHOLD_ATTRIBUTE = "threshold";

    /** Reference series every candidate window is compared with. */
    private Double[] pattern;
    /** Inclusive upper bound on the DTW distance of a match. */
    private float threshold;
    /** Most recent state handed out by {@link #createState()}; resized in {@link #beforeStart}. */
    private DTWState state;

    /**
     * Declares TEXT output and parses the {@code threshold} and {@code pattern} attributes.
     *
     * <p>If a state was already created, its size is updated to the parsed pattern length.
     *
     * @param udfParameters source of the attribute map
     * @param udafConfigurations receives the output data type
     * @throws NumberFormatException if the threshold or any pattern element is not a number
     */
    @Override
    public void beforeStart(UDFParameters udfParameters, UDAFConfigurations udafConfigurations) {
        udafConfigurations.setOutputDataType(Type.TEXT);
        Map<String, String> attributes = udfParameters.getAttributes();
        this.threshold = Float.parseFloat(attributes.get(THRESHOLD_ATTRIBUTE));
        this.pattern = Arrays.stream(attributes.get(PATTERN_ATTRIBUTE).split(","))
                .map(Double::valueOf)
                .toArray(Double[]::new);
        if (this.state != null) {
            this.state.setSize(this.pattern.length);
        }
    }

    /**
     * Creates a new state, sized to the pattern when the pattern is already known, and remembers
     * it in {@link #state}.
     *
     * @return the newly created {@link DTWState}
     */
    @Override
    public State createState() {
        this.state = this.pattern != null ? new DTWState(this.pattern.length) : new DTWState();
        return this.state;
    }

    /**
     * Feeds every selected, non-null row of the batch into the state and records matches.
     *
     * <p>Values are read from {@code columns[0]} and timestamps from {@code columns[1]}.
     *
     * @param state the {@link DTWState} to update
     * @param columns value column followed by the timestamp column
     * @param bitMap optional row filter; when non-null only marked rows are processed
     */
    @Override
    public void addInput(State state, Column[] columns, BitMap bitMap) {
        DTWState dtwState = (DTWState) state;
        int count = columns[0].getPositionCount();
        for (int i = 0; i < count; ++i) {
            if (bitMap != null && !bitMap.isMarked(i)) {
                continue;
            }
            // Both cells are read below, so a null in either one makes the row unusable.
            if (columns[1].isNull(i) || columns[0].isNull(i)) {
                continue;
            }
            long timestamp = columns[1].getLong(i);
            double value = this.getValue(columns[0], i);
            this.offerSample(dtwState, timestamp, value);
        }
    }

    /**
     * Appends one sample to the state's buffer and, once the buffered window is exactly as long
     * as the pattern, records a match if its DTW distance is within the threshold.
     */
    private void offerSample(DTWState dtwState, long timestamp, double value) {
        dtwState.updateBuffer(timestamp, value);
        Double[] window = dtwState.getValueBuffer();
        if (window.length != this.pattern.length) {
            return;
        }
        float distance = this.calculateDTW(window, this.pattern);
        // A NaN distance fails this comparison and is therefore never reported.
        if (distance <= this.threshold) {
            dtwState.addMatchResult(
                    new DTWMatchResult(distance, dtwState.getFirstTime(), dtwState.getLastTime()));
        }
    }

    /**
     * Reads position {@code i} of a numeric or boolean column as a double (booleans map to 1/0).
     *
     * @throws UnsupportedOperationException if the column has any other data type
     */
    private double getValue(Column column, int i) {
        switch (column.getDataType()) {
            case INT32: {
                return column.getInt(i);
            }
            case INT64: {
                return column.getLong(i);
            }
            case FLOAT: {
                return column.getFloat(i);
            }
            case DOUBLE: {
                return column.getDouble(i);
            }
            case BOOLEAN: {
                return column.getBoolean(i) ? 1.0 : 0.0;
            }
        }
        throw new UnsupportedOperationException(
                String.format("Unsupported datatype %s", column.getDataType()));
    }

    /**
     * Computes the DTW distance between two series using absolute difference as the local cost.
     *
     * <p>The accumulated-cost matrix has one extra leading row and column representing the empty
     * prefix, so the first element of each series contributes to the distance like every other
     * element.
     *
     * @param series1 first series
     * @param series2 second series
     * @return the accumulated cost of the cheapest warping path, narrowed to float; {@code 0} when
     *     both series are empty and positive infinity when exactly one of them is empty
     */
    private float calculateDTW(Double[] series1, Double[] series2) {
        int n = series1.length;
        int m = series2.length;
        double[][] dtw = new double[n + 1][m + 1];
        for (double[] row : dtw) {
            Arrays.fill(row, Double.POSITIVE_INFINITY);
        }
        // Aligning two empty prefixes costs nothing; every other boundary cell stays unreachable.
        dtw[0][0] = 0.0;
        for (int i = 1; i <= n; ++i) {
            for (int j = 1; j <= m; ++j) {
                double cost = Math.abs(series1[i - 1] - series2[j - 1]);
                double cheapestPredecessor =
                        Math.min(Math.min(dtw[i - 1][j], dtw[i][j - 1]), dtw[i - 1][j - 1]);
                dtw[i][j] = cost + cheapestPredecessor;
            }
        }
        return (float) dtw[n][m];
    }

    /**
     * Replays the buffered samples of {@code state1} into {@code state}, skipping samples whose
     * timestamp is not later than the first buffered time of {@code state}.
     *
     * <p>NOTE(review): the cut-off uses {@code getFirstTime()} rather than {@code getLastTime()},
     * and match results already stored in {@code state1} are not copied over; whether either is
     * intended depends on {@code DTWState} and the caller - confirm.
     *
     * @param state the {@link DTWState} that receives the samples
     * @param state1 the {@link DTWState} whose buffers are replayed
     */
    @Override
    public void combineState(State state, State state1) {
        DTWState target = (DTWState) state;
        DTWState source = (DTWState) state1;
        Long[] times = source.getTimeBuffer();
        Double[] values = source.getValueBuffer();
        for (int i = 0; i < times.length; ++i) {
            if (times[i] <= target.getFirstTime()) {
                continue;
            }
            this.offerSample(target, times[i], values[i]);
        }
    }

    /**
     * Runs the matcher over in-memory lists instead of column batches.
     *
     * <p>Overwrites this instance's pattern and threshold and replaces its current state with a
     * freshly created, reset one.
     *
     * @param times sample timestamps
     * @param values sample values, read at the same indices as {@code times}
     * @param pattern reference series
     * @param threshold inclusive upper bound on the DTW distance of a match
     * @return the matches collected by the new state
     */
    public List<DTWMatchResult> calcMatch(List<Long> times, List<Double> values, Double[] pattern, float threshold) {
        this.pattern = pattern;
        this.threshold = threshold;
        DTWState dtwState = (DTWState) this.createState();
        dtwState.reset();
        for (int i = 0; i < times.size(); ++i) {
            this.offerSample(dtwState, times.get(i), values.get(i));
        }
        return dtwState.getMatchResults();
    }

    /**
     * Writes the string form of the collected matches, or null when there are none.
     *
     * @param state the {@link DTWState} holding the matches
     * @param resultValue receives the output
     */
    @Override
    public void outputFinal(State state, ResultValue resultValue) {
        DTWState dtwState = (DTWState) state;
        List<DTWMatchResult> matchResults = dtwState.getMatchResults();
        if (matchResults.isEmpty()) {
            resultValue.setNull();
            return;
        }
        // Encode with a fixed charset so the output does not vary with the platform default.
        resultValue.setBinary(new Binary(matchResults.toString(), StandardCharsets.UTF_8));
    }

    /** Intentionally empty: this implementation does nothing when asked to remove a state. */
    @Override
    public void removeState(State state, State removed) {
    }

    /**
     * Requires exactly one input series of type INT32, INT64, FLOAT, DOUBLE or BOOLEAN and the
     * presence of the {@code pattern} and {@code threshold} attributes.
     */
    @Override
    public void validate(UDFParameterValidator validator) {
        validator
                .validateInputSeriesNumber(1)
                .validateInputSeriesDataType(0, Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE, Type.BOOLEAN)
                .validateRequiredAttribute(PATTERN_ATTRIBUTE)
                .validateRequiredAttribute(THRESHOLD_ATTRIBUTE);
    }
}

