/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.component.srl;

import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * Statistical component that assigns a sense label to predicates of a {@link DEPTree}.
 *
 * <p>The label is stored as the node feature named by the {@code key} passed to the
 * constructor; {@link #classify()} special-cases the keys {@code "pb"} and {@code "vn"}.
 *
 * <p>Lemmas observed with exactly one sense are resolved by a direct lookup
 * ({@code m_senses}). Lemmas observed with several senses are mapped to a model index
 * ({@code m_lemmas}) and labelled by the corresponding statistical model.
 *
 * <p>Serialized zip entries are named {@code "sense_" + key + "_CONFIGURATION"},
 * {@code "_FEATURE"}, {@code "_LEXICA"} and {@code "_MODEL"}.
 */
public class CSenseClassifier
extends AbstractStatisticalComponent {
    // Zip entry names (or prefixes, for feature/model entries), derived from the key in initKey().
    private String ENTRY_CONFIGURATION;
    private String ENTRY_FEATURE;
    private String ENTRY_LEXICA;
    private String ENTRY_MODEL;
    /** Index of the lemma-to-sense map inside the lexica array. */
    protected final int LEXICA_SENSES = 0;
    /** Index of the lemma-to-model-id map inside the lexica array. */
    protected final int LEXICA_LEMMAS = 1;
    /**
     * Senses observed per lemma while collecting lexica. Initialized here rather than
     * in one constructor, so getLexica()/addLexica() cannot hit a null map on any
     * construction path.
     */
    protected Map<String, Set<String>> m_collect = new HashMap<String, Set<String>>();
    /** Lemma to its only observed sense (unambiguous lemmas). */
    protected Map<String, String> m_senses;
    /** Lemma to model index (ambiguous lemmas). */
    protected ObjectIntOpenHashMap<String> m_lemmas;
    /** Gold senses of the current tree, indexed by node id; not loaded when i_flag == 2. */
    protected String[] g_senses;
    /** Index of the node currently being processed. */
    protected int i_pred;
    /** Feature key under which senses are read and written (e.g. "pb", "vn"). */
    protected String s_key;

    /**
     * Creates a component that collects lexica.
     *
     * @param xmls feature templates
     * @param key  sense feature key
     */
    public CSenseClassifier(JointFtrXml[] xmls, String key) {
        super(xmls);
        this.initKey(key);
    }

    /**
     * Creates a component that fills the given training spaces.
     *
     * @param xmls   feature templates
     * @param spaces training spaces, one per model index
     * @param lexica lexica as returned by {@link #getLexica()}
     * @param key    sense feature key
     */
    public CSenseClassifier(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica, String key) {
        super(xmls, spaces, lexica);
        this.initKey(key);
    }

    /**
     * Creates a component backed by already-trained models.
     *
     * @param xmls   feature templates
     * @param models statistical models, one per model index
     * @param lexica lexica as returned by {@link #getLexica()}
     * @param key    sense feature key
     */
    public CSenseClassifier(JointFtrXml[] xmls, StringModel[] models, Object[] lexica, String key) {
        super(xmls, models, lexica);
        this.initKey(key);
    }

    /**
     * Creates a decoding component (i_flag = 2) and loads its models from a zip stream.
     *
     * @param in  zip stream produced by {@link #saveModels(ZipOutputStream)}
     * @param key sense feature key
     */
    public CSenseClassifier(ZipInputStream in, String key) {
        this.i_flag = (byte)2;
        this.initKey(key);
        this.loadModels(in);
    }

    /**
     * Installs the lexica array: a {@code Map<String,String>} at {@link #LEXICA_SENSES}
     * and an {@code ObjectIntOpenHashMap<String>} at {@link #LEXICA_LEMMAS}.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected void initLexia(Object[] lexica) {
        // The indices are referenced by simple name. That makes them compile-time
        // constants, so they are correct even if this runs before this class's
        // field initializers (e.g. when called from a superclass constructor).
        this.m_senses = (Map<String, String>)lexica[LEXICA_SENSES];
        this.m_lemmas = (ObjectIntOpenHashMap<String>)lexica[LEXICA_LEMMAS];
    }

    /** Derives the zip entry names from {@code key} and remembers the key. */
    private void initKey(String key) {
        this.ENTRY_CONFIGURATION = "sense_" + key + "_CONFIGURATION";
        this.ENTRY_FEATURE = "sense_" + key + "_FEATURE";
        this.ENTRY_LEXICA = "sense_" + key + "_LEXICA";
        this.ENTRY_MODEL = "sense_" + key + "_MODEL";
        this.s_key = key;
    }

    /**
     * Reads configuration, feature templates, lexica and models from the zip stream.
     * Entries for other keys are skipped. Any failure is printed and the load stops.
     */
    @Override
    public void loadModels(ZipInputStream zin) {
        int fLen = this.ENTRY_FEATURE.length();
        int mLen = this.ENTRY_MODEL.length();
        this.f_xmls = new JointFtrXml[1];
        this.s_models = null;
        try {
            ZipEntry zEntry;
            while ((zEntry = zin.getNextEntry()) != null) {
                String entry = zEntry.getName();
                if (entry.equals(this.ENTRY_CONFIGURATION)) {
                    this.loadDefaultConfiguration(zin);
                } else if (entry.startsWith(this.ENTRY_FEATURE)) {
                    // The suffix after the prefix is the template index.
                    this.loadFeatureTemplates(zin, Integer.parseInt(entry.substring(fLen)));
                } else if (entry.startsWith(this.ENTRY_MODEL)) {
                    // The suffix after the prefix is the model index.
                    this.loadStatisticalModels(zin, Integer.parseInt(entry.substring(mLen)));
                } else if (entry.equals(this.ENTRY_LEXICA)) {
                    this.loadLexica(zin);
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Reads the sense map, then the lemma map, from the current zip entry. */
    private void loadLexica(ZipInputStream zin) throws Exception {
        // Deliberately not closed: closing the reader would close the shared zip stream.
        BufferedReader fin = new BufferedReader(new InputStreamReader(zin));
        System.out.println("Loading lexica.");
        this.m_senses = UTInput.getStringMap(fin, " ");
        this.m_lemmas = UTInput.getStringIntOpenHashMap(fin, " ");
    }

    /**
     * Writes configuration, feature templates, lexica and models, then closes
     * {@code zout}. Failures are printed rather than thrown.
     */
    @Override
    public void saveModels(ZipOutputStream zout) {
        try {
            this.saveDefaultConfiguration(zout, this.ENTRY_CONFIGURATION);
            this.saveFeatureTemplates(zout, this.ENTRY_FEATURE);
            this.saveLexica(zout);
            this.saveStatisticalModels(zout, this.ENTRY_MODEL);
            zout.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Writes the sense map followed by the lemma map into a single lexica entry. */
    private void saveLexica(ZipOutputStream zout) throws Exception {
        zout.putNextEntry(new ZipEntry(this.ENTRY_LEXICA));
        // Flushed but not closed, so the zip stream stays open for later entries.
        PrintStream fout = UTOutput.createPrintBufferedStream(zout);
        UTOutput.printMap(fout, this.m_senses, " ");
        fout.flush();
        UTOutput.printMap(fout, this.m_lemmas, " ");
        fout.flush();
        zout.closeEntry();
    }

    /**
     * Builds lexica from the collected senses.
     *
     * @return {@code {senseMap, lemmaMap}}: lemmas with exactly one sense map to that
     *         sense, and every other collected lemma maps to a model index
     */
    @Override
    public Object[] getLexica() {
        Map<String, String> mSenses = this.getSenseMap();
        return new Object[]{mSenses, this.getLemmas(this.m_collect.keySet(), mSenses)};
    }

    /** Maps every collected lemma with exactly one observed sense to that sense. */
    private Map<String, String> getSenseMap() {
        Map<String, String> map = new HashMap<String, String>();
        for (Map.Entry<String, Set<String>> entry : this.m_collect.entrySet()) {
            Set<String> senses = entry.getValue();
            if (senses.size() == 1) {
                map.put(entry.getKey(), senses.iterator().next());
            }
        }
        return map;
    }

    /** Assigns consecutive model indices, from 0, to lemmas absent from {@code mSenses}. */
    private ObjectIntOpenHashMap<String> getLemmas(Set<String> sLemmas, Map<String, String> mSenses) {
        ObjectIntOpenHashMap<String> map = new ObjectIntOpenHashMap<String>();
        int modelId = 0;
        for (String lemma : sLemmas) {
            if (!mSenses.containsKey(lemma)) {
                map.put(lemma, modelId++);
            }
        }
        return map;
    }

    /**
     * Adds to {@code counts[0]} the number of nodes with a gold sense, and to
     * {@code counts[1]} the number whose predicted sense equals it.
     */
    @Override
    public void countAccuracy(int[] counts) {
        int correct = 0;
        int total = 0;
        for (int i = 1; i < this.t_size; ++i) {
            String gRoleset = this.g_senses[i];
            if (gRoleset == null) continue;
            ++total;
            if (gRoleset.equals(this.d_tree.get(i).getFeat(this.s_key))) {
                ++correct;
            }
        }
        counts[0] = counts[0] + total;
        counts[1] = counts[1] + correct;
    }

    @Override
    public void process(DEPTree tree) {
        this.init(tree);
        this.processAux();
    }

    /** Binds the tree, loads gold senses unless decoding (i_flag == 2), and sets dependents. */
    protected void init(DEPTree tree) {
        this.d_tree = tree;
        this.t_size = tree.size();
        if (this.i_flag != 2) {
            this.g_senses = this.d_tree.getSenses(this.s_key);
        }
        tree.setDependents();
    }

    /** Collects lexica when i_flag == 0; otherwise classifies. */
    protected void processAux() {
        if (this.i_flag == 0) {
            this.addLexica();
        } else {
            this.classify();
        }
    }

    /** Records each gold sense of the current tree under its node's lemma. */
    protected void addLexica() {
        for (this.i_pred = 1; this.i_pred < this.t_size; ++this.i_pred) {
            String sense = this.g_senses[this.i_pred];
            if (sense == null) continue;
            String lemma = this.d_tree.get(this.i_pred).lemma;
            Set<String> set = this.m_collect.get(lemma);
            if (set == null) {
                set = new HashSet<String>();
                this.m_collect.put(lemma, set);
            }
            set.add(sense);
        }
    }

    /**
     * Assigns a sense to each node carrying the cue feature.
     *
     * <p>The cue is the sense key itself, except that when decoding with key
     * {@code "vn"} the {@code "pb"} feature marks the predicates. The sense comes, in
     * order of preference, from:
     * <ol>
     *   <li>the unambiguous-lemma map,</li>
     *   <li>the lemma's model (see {@link #getLabel(int)}),</li>
     *   <li>{@code lemma + ".01"} for key {@code "pb"}, or {@code "unknown"} for key {@code "vn"}.</li>
     * </ol>
     */
    protected void classify() {
        String cue = this.i_flag == 2 && "vn".equals(this.s_key) ? "pb" : this.s_key;
        for (this.i_pred = 1; this.i_pred < this.t_size; ++this.i_pred) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            if (pred.getFeat(cue) == null) continue;
            String sense = this.m_senses.get(pred.lemma);
            if (sense == null) {
                if (this.m_lemmas.containsKey(pred.lemma)) {
                    sense = this.getLabel(this.m_lemmas.get(pred.lemma));
                } else if ("pb".equals(this.s_key)) {
                    sense = pred.lemma + ".01";
                } else if ("vn".equals(this.s_key)) {
                    sense = "unknown";
                }
            }
            // NOTE(review): sense can still be null here (unseen lemma with another key, or
            // getLabel under an unexpected flag); confirm DEPNode.addFeat tolerates null.
            pred.addFeat(this.s_key, sense);
        }
    }

    /**
     * Returns the label for the current node using model {@code modelId}.
     *
     * <p>When i_flag == 1 it returns the gold label and adds a training instance.
     * When i_flag is 2 or 4 it returns the model's best prediction. For any other flag
     * it returns null.
     */
    protected String getLabel(int modelId) {
        StringFeatureVector vector = this.getFeatureVector(this.f_xmls[0]);
        if (this.i_flag == 1) {
            String label = this.getGoldLabel();
            this.s_spaces[modelId].addInstance(label, vector);
            return label;
        }
        if (this.i_flag == 2 || this.i_flag == 4) {
            return this.getAutoLabel(vector, modelId);
        }
        return null;
    }

    private String getGoldLabel() {
        return this.g_senses[this.i_pred];
    }

    private String getAutoLabel(StringFeatureVector vector, int modelId) {
        StringPrediction p = this.s_models[modelId].predictBest(vector);
        return p.label;
    }

    /**
     * Extracts a single-valued feature from the node selected by {@code token}.
     *
     * <p>Supported fields:
     * <ul>
     *   <li>{@code f}: form</li>
     *   <li>{@code m}: lemma</li>
     *   <li>{@code p}: POS</li>
     *   <li>{@code d}: dependency label</li>
     *   <li>{@code ldp}/{@code rdp}: POS tags of left/right dependents, joined by {@code "_"}</li>
     *   <li>{@code ldd}/{@code rdd}: labels of left/right dependents, joined by {@code "_"}</li>
     *   <li>anything matching {@code JointFtrXml.P_FEAT}: the node feature named by group 1</li>
     * </ul>
     * Returns null when the node is absent or the field is not recognized.
     */
    @Override
    protected String getField(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("f")) {
            return node.form;
        }
        if (token.isField("m")) {
            return node.lemma;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("d")) {
            return node.getLabel();
        }
        if (token.isField("ldp")) {
            return this.getDependents(node.getLeftDependents(), "p");
        }
        if (token.isField("rdp")) {
            return this.getDependents(node.getRightDependents(), "p");
        }
        if (token.isField("ldd")) {
            return this.getDependents(node.getLeftDependents(), "d");
        }
        if (token.isField("rdd")) {
            return this.getDependents(node.getRightDependents(), "d");
        }
        Matcher m = JointFtrXml.P_FEAT.matcher(token.field);
        if (m.find()) {
            return node.getFeat(m.group(1));
        }
        return null;
    }

    /**
     * Joins the POS tags ({@code type == "p"}) or labels (any other type) of
     * {@code nodes} with {@code "_"}. Returns null for an empty list.
     */
    private String getDependents(List<DEPNode> nodes, String type) {
        if (nodes.isEmpty()) {
            return null;
        }
        boolean isPos = type.equals("p");
        StringBuilder build = new StringBuilder();
        for (DEPNode node : nodes) {
            build.append("_");
            build.append(isPos ? node.pos : node.getLabel());
        }
        return build.substring(1);
    }

    /**
     * Extracts a multi-valued feature. Only {@code ds} (the set of the node's dependency
     * labels) is supported; anything else yields null.
     */
    @Override
    protected String[] getFields(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("ds")) {
            return this.getDeprelSet(node.getDependents());
        }
        return null;
    }

    /** Returns the distinct labels of {@code deps} (in no particular order), or null if empty. */
    private String[] getDeprelSet(List<DEPArc> deps) {
        if (deps.isEmpty()) {
            return null;
        }
        Set<String> set = new HashSet<String>();
        for (DEPArc arc : deps) {
            set.add(arc.getLabel());
        }
        return set.toArray(new String[set.size()]);
    }

    /**
     * Resolves the token's offset node, then applies its optional relation:
     * {@code h} (head), {@code lmd}/{@code rmd} (left/right-most dependent), or
     * {@code lnd}/{@code rnd} (left/right nearest dependent). May return null.
     */
    private DEPNode getNode(FtrToken token) {
        DEPNode node = this.getNodeAux(token);
        if (node == null || token.relation == null) {
            return node;
        }
        if (token.isRelation("h")) {
            return node.getHead();
        }
        if (token.isRelation("lmd")) {
            return node.getLeftMostDependent();
        }
        if (token.isRelation("rmd")) {
            return node.getRightMostDependent();
        }
        if (token.isRelation("lnd")) {
            return node.getLeftNearestDependent();
        }
        if (token.isRelation("rnd")) {
            return node.getRightNearestDependent();
        }
        return node;
    }

    /**
     * Returns the node at {@code i_pred + token.offset}, or null when that index falls
     * outside {@code (0, tree size)}. Index 0 is excluded, matching the loops in this
     * class that start at 1 (presumably an artificial root; confirm).
     */
    private DEPNode getNodeAux(FtrToken token) {
        if (token.offset == 0) {
            return this.d_tree.get(this.i_pred);
        }
        int cIndex = this.i_pred + token.offset;
        if (0 < cIndex && cIndex < this.d_tree.size()) {
            return this.d_tree.get(cIndex);
        }
        return null;
    }
}

