/*
 * Decompiled with CFR 0.152.
 */
package com.insightful.miner;

import com.insightful.cnkjava.CNKProc;
import com.insightful.cnkjava.CNKProcCallback;
import com.insightful.cnkjava.CNKProcNNet;
import com.insightful.cnkjava.CNKProcNNetPredict;
import com.insightful.miner.EngineMessageHandler;
import com.insightful.miner.EngineNode;
import com.insightful.miner.MinerApp;
import com.insightful.miner.NeuralNetworkExecViewer;
import com.insightful.miner.PredictEngineNode;
import com.insightful.miner.XMLTree;
import com.insightful.miner.XTMetaData;
import com.insightful.miner.XTProps;
import java.util.Hashtable;
import java.util.Vector;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * Engine node that trains a regression neural network on its first input and then,
 * when output columns are configured, scores that same input with the fitted model.
 *
 * <p>Training runs through {@link CNKProcNNet}. This node registers itself as the
 * proc's {@link CNKProcCallback}. Per-epoch error statistics are streamed to a
 * {@link NeuralNetworkExecViewer} when the viewer is enabled. Otherwise they are
 * accumulated and written to the {@code "NNerrorWeightCache"} node cache once the
 * search completes. Prediction runs through {@link CNKProcNNetPredict}.
 */
class RegressionNeuralNetworkEngineNode
extends EngineNode
implements CNKProcCallback {
    /**
     * Property key for the convergence tolerance handed to the trainer.
     * NOTE(review): the key string is "attribute" rather than something like "accuracy";
     * it is presumably persisted in saved node properties, so it is left unchanged.
     */
    public static final String ACCURACY_ATTRIBUTE_TAG = "attribute";
    /** Property key for the maximum number of training epochs. */
    public static final String EPOCHS_ATTRIBUTE_TAG = "epochs";
    /** Property key for the learning rate. */
    public static final String LEARNING_RATE_ATTRIBUTE_TAG = "learningRate";
    /** Property key for the momentum term. */
    public static final String MOMENTUM_ATTRIBUTE_TAG = "momentum";
    /** Property key for the weight-decay factor. */
    public static final String WEIGHT_DECAY_ATTRIBUTE_TAG = "weightDecay";
    /** Property key for the number of hidden layers. */
    public static final String NUMBER_HIDDEN_LAYERS_ATTRIBUTE_TAG = "hiddenLayers";
    /** Property key for the training method; values are the *_METHOD_ATTRIBUTE_TAG constants below. */
    public static final String TRAINING_METHOD_ATTRIBUTE_TAG = "trainingMethod";
    /** Property key for the number of nodes per hidden layer. */
    public static final String NUMBER_NODES_HIDDEN_LAYER_ATTRIBUTE_TAG = "numNodesHiddenLayer";
    public static final String ONLINE_METHOD_ATTRIBUTE_TAG = "online";
    public static final String ONLINE_FAST_METHOD_ATTRIBUTE_TAG = "onlineFast";
    /** Default training method when the property is absent. */
    public static final String BATCH_METHOD_ATTRIBUTE_TAG = "batch";
    public static final String BATCH_FAST_METHOD_ATTRIBUTE_TAG = "batchFast";
    public static final String RPROP_METHOD_ATTRIBUTE_TAG = "rprop";
    public static final String QUICK_PROP_METHOD_ATTRIBUTE_TAG = "quickProp";
    public static final String DELTA_BAR_DELTA_METHOD_ATTRIBUTE_TAG = "deltaBarDelta";
    public static final String CONJUGATE_GRADIENT_METHOD_ATTRIBUTE_TAG = "conjugateGradient";
    public static final String VMETRIC_METHOD_ATTRIBUTE_TAG = "vmetric";
    public static final String POWELL_METHOD_ATTRIBUTE_TAG = "powell";
    public static final String EVOLUTIONARY_METHOD_ATTRIBUTE_TAG = "evolutionary";
    /** Property key: whether to show the interactive execution viewer while training. */
    public static final String SHOW_EXEC_VIEWER_ATTRIBUTE_TAG = "showExecViewer";
    /** Property key for the percentage of rows held out for validation. */
    public static final String VALIDATION_PERCENT_ATTRIBUTE_TAG = "validationPercent";
    /** Property key: start training from previously saved weights instead of a fresh model. */
    public static final String INIT_WITH_PREV_WEIGHTS_ATTRIBUTE_TAG = "initWithPrevWeights";
    /** Property key for an explicit weights file used when initialising from previous weights. */
    public static final String INIT_WEIGHTS_FROM_FILE_TAG = "initWithWeightsFromFile";
    /** Property key selecting which weights form the final model (see NNET_FINAL_MODEL_*). */
    public static final String FINAL_MODEL_OPTION_TAG = "finalModelOption";
    /** Final-model option 1; also the default for {@link #FINAL_MODEL_OPTION_TAG}. */
    public static final int NNET_FINAL_MODEL_BEST = 1;
    /** Final-model option 2. */
    public static final int NNET_FINAL_MODEL_LAST = 2;

    /** Node-cache key holding per-epoch error values (and optional arrow values). */
    private static final String ERROR_WEIGHT_CACHE_KEY = "NNerrorWeightCache";
    /** Node-cache key holding the fitted model XML. */
    private static final String MODEL_CACHE_KEY = "model";
    /** Fully-qualified class name of the execution viewer to create. */
    private static final String EXEC_VIEWER_CLASS = "com.insightful.miner.NeuralNetworkExecViewer";
    // Defaults used when the corresponding property is absent.
    private static final int DEFAULT_EPOCHS = 50;
    private static final double DEFAULT_ACCURACY = 0.00001;
    private static final int DEFAULT_HIDDEN_LAYERS = 1;
    private static final int DEFAULT_NODES_PER_HIDDEN_LAYER = 10;
    private static final double DEFAULT_LEARNING_RATE = 0.001;
    private static final double DEFAULT_MOMENTUM = 0.0;
    private static final double DEFAULT_WEIGHT_DECAY = 1.0;
    private static final int DEFAULT_VALIDATION_PERCENT = 10;
    /** Value passed to {@code setProgressNotification} on the trainer. */
    private static final int PROGRESS_NOTIFICATION = 5;
    /** Viewer terminate option for which the viewer's retrieve-weights file is passed to setUserStop. */
    private static final int USER_STOP_WEIGHTS_FROM_FILE = 3;
    /** Value passed to {@code setJitterWeights} when the viewer requests a jitter. */
    private static final int JITTER_AMOUNT = 75;

    // Per-epoch training / validation error and epoch numbers, accumulated as strings
    // when the viewer is disabled; flushed to the node cache when the search completes.
    private Vector mseTrainVector = null;
    private Vector mseValidateVector = null;
    private Vector epochVector = null;
    /** Highest epoch already forwarded to the viewer; reset at the start of each training run. */
    private int m_lastEpoch = 0;
    /** Model XML captured from the trainer in procExtractResults; cleared after each run. */
    private String m_modelString = null;

    public boolean hasCNKProc() {
        return false;
    }

    public boolean hasDataCacheProc() {
        return true;
    }

    /**
     * Captures the fitted model from a finished {@link CNKProcNNet}. The model's
     * {@code ErrorWeights} content is replaced with the error-weight node cache (if any).
     * The result is stored under the {@code "model"} node cache. Other proc types are
     * ignored.
     *
     * @param proc the finished proc
     * @throws Exception if the model XML cannot be parsed or has no NeuralNetwork element
     */
    public void procExtractResults(CNKProc proc) throws Exception {
        if (!(proc instanceof CNKProcNNet)) {
            return;
        }
        this.m_modelString = ((CNKProcNNet)proc).getModel();
        XMLTree fitted = XMLTree.readFromString(this.m_modelString);
        XMLTree errorCache = this.getNodeCache(ERROR_WEIGHT_CACHE_KEY);
        Element network = (Element)fitted.getXML().getElementsByTagName("NeuralNetwork").item(0);
        if (network == null) {
            throw new Exception("neural network model has no NeuralNetwork element");
        }
        // getElementsByTagName returns a *live* list of all descendants. Walk it backwards
        // so removals don't shift unvisited entries. Detach each node from its actual
        // parent, which need not be `network` itself.
        NodeList stale = network.getElementsByTagName("ErrorWeights");
        for (int i = stale.getLength() - 1; i >= 0; --i) {
            Node node = stale.item(i);
            node.getParentNode().removeChild(node);
        }
        Element errorWeights = fitted.getDocument().createElement("ErrorWeights");
        // Only embed the cache when it differs from an empty XMLTree.
        if (errorCache != null && !new XMLTree().toString().equals(errorCache.toString())) {
            errorWeights.appendChild(fitted.getDocument().importNode(errorCache.getXML(), true));
        }
        network.appendChild(errorWeights);
        this.setNodeCache(MODEL_CACHE_KEY, fitted);
    }

    /**
     * Returns the dependent variable: the first one matching
     * {@link XTMetaData#CONTINUOUS_TYPE_ATTRIBUTE_TAG}, or null if there is none.
     */
    public String getDepVar(XTProps props, XTMetaData md) {
        return PredictEngineNode.getFirstDependentVar(props, md, XTMetaData.CONTINUOUS_TYPE_ATTRIBUTE_TAG);
    }

    /** Tells the execution viewer whether this node is a regression (true here). */
    protected boolean getIsTypeRegression() {
        return true;
    }

    /** Returns the output specifications for the prediction step, built from the node properties and input 0. */
    public Vector getOutputSpecs() {
        XTProps props = this.getNodeProperties();
        XTMetaData md = this.getInputMetaData(0);
        String depVar = this.getDepVar(props, md);
        Vector indepVars = PredictEngineNode.getIndependentVars(props, md);
        return PredictEngineNode.getOutputSpecs(md, props, depVar, indepVars);
    }

    /**
     * Builds the model's data dictionary. This is a clone of input 0's metadata restricted
     * to the independent variables plus the dependent variable, with each field's role set.
     *
     * @return the dictionary serialised as a string
     * @throws Exception propagated from the metadata serialisation
     */
    public String getDataDictionaryAsString() throws Exception {
        XTProps props = this.getNodeProperties();
        XTMetaData md = this.getInputMetaData(0);
        Vector varNames = PredictEngineNode.getIndependentVars(props, md);
        String depCol = this.getDepVar(props, md);
        // Copy rather than alias: the vector from getIndependentVars is not ours to mutate.
        Vector allNames = new Vector(varNames);
        allNames.add(depCol);
        XTMetaData mdModel = md.selectiveClone(allNames);
        for (int i = 0; i < varNames.size(); ++i) {
            mdModel.setDataFieldRole((String)varNames.get(i), "independent");
        }
        mdModel.setDataFieldRole(depCol, "dependent");
        return mdModel.writeToString();
    }

    /**
     * Trains the network and, when output specs exist, scores the training data.
     * The application status text is restored and all procs are released on every
     * exit path, including exceptions.
     *
     * @return false if training failed or produced no model. Otherwise true when no
     *         outputs are configured, or the result of the prediction proc.
     * @throws Exception if there is no independent or dependent variable, or propagated
     *         from model construction / proc execution
     */
    public boolean executeDataCacheProc() throws Exception {
        XTProps props = this.getNodeProperties();
        XTMetaData md = this.getInputMetaData(0);
        String depVar = this.getDepVar(props, md);
        Vector indepVars = PredictEngineNode.getIndependentVars(props, md);
        if (indepVars.size() <= 0) {
            throw new Exception("can't create neural network: no independent variable(s)");
        }
        if (depVar == null) {
            throw new Exception("can't create neural network: no dependent variable");
        }
        String origText = (String)EngineMessageHandler.sendMessageToApp("getStatusText", new Object[0]);
        // Drop any model left from an earlier run so it can't be mistaken for this run's result.
        this.m_modelString = null;
        try {
            if (!this.trainNetwork(props, origText) || this.m_modelString == null) {
                return false;
            }
            return this.predictTrainingData(origText);
        } finally {
            this.m_modelString = null;
            EngineMessageHandler.sendMessageToApp("setStatusText", new Object[]{origText});
        }
    }

    /** Runs the training proc (optionally with the exec viewer) and always releases it. */
    private boolean trainNetwork(XTProps props, String origText) throws Exception {
        EngineMessageHandler.sendMessageToApp("setStatusText", new Object[]{origText + ": Building model..."});
        CNKProcNNet proc1 = new CNKProcNNet();
        boolean viewerOpen = false;
        try {
            this.configureInitialModel(proc1, props);
            proc1.setOptParameters(
                    props.getInt(EPOCHS_ATTRIBUTE_TAG, DEFAULT_EPOCHS),
                    props.getDouble(ACCURACY_ATTRIBUTE_TAG, DEFAULT_ACCURACY),
                    props.getInt(NUMBER_HIDDEN_LAYERS_ATTRIBUTE_TAG, DEFAULT_HIDDEN_LAYERS),
                    props.getInt(NUMBER_NODES_HIDDEN_LAYER_ATTRIBUTE_TAG, DEFAULT_NODES_PER_HIDDEN_LAYER),
                    props.getValue(TRAINING_METHOD_ATTRIBUTE_TAG, BATCH_METHOD_ATTRIBUTE_TAG),
                    props.getDouble(LEARNING_RATE_ATTRIBUTE_TAG, DEFAULT_LEARNING_RATE),
                    props.getDouble(MOMENTUM_ATTRIBUTE_TAG, DEFAULT_MOMENTUM),
                    props.getDouble(WEIGHT_DECAY_ATTRIBUTE_TAG, DEFAULT_WEIGHT_DECAY),
                    props.getInt(FINAL_MODEL_OPTION_TAG, NNET_FINAL_MODEL_BEST));
            proc1.setValidationPercent(props.getInt(VALIDATION_PERCENT_ATTRIBUTE_TAG, DEFAULT_VALIDATION_PERCENT));
            // The trainer receives only the low 16 bits of the node's random seed.
            proc1.setSeed((int)(this.getRandomSeed() & 0xFFFFL));
            if (this.shouldShowNNetExecViewer(props)) {
                this.createExecViewer(EXEC_VIEWER_CLASS);
                viewerOpen = true;
                this.sendMessageToExecViewer("setIsRegression", new Object[]{Boolean.valueOf(this.getIsTypeRegression())});
                Hashtable<String, XMLTree> nodeData = new Hashtable<String, XMLTree>();
                nodeData.put(NeuralNetworkExecViewer.XTPROPS_TAG, this.getNodeProperties());
                nodeData.put(NeuralNetworkExecViewer.METADATA_TAG, this.getInputMetaData(0));
                this.sendMessageToExecViewer("setNeuralNetProps", new Object[]{nodeData});
            }
            this.m_lastEpoch = 0;
            proc1.setCallback(this);
            proc1.setProgressNotification(PROGRESS_NOTIFICATION);
            this.printlnVerbose("neural network: creating model");
            return this.getNetworkManager().executeCNKProc(this.getNodeID(), proc1);
        } finally {
            proc1.setCallback(null);
            if (viewerOpen) {
                this.sendMessageToExecViewer("loadFinalWeights", new Object[]{this.m_modelString});
                this.closeExecViewer();
            }
            this.procDelete(proc1);
        }
    }

    /**
     * Seeds the trainer with a weights file when initialising from previous weights.
     * Otherwise it starts from a fresh data dictionary. The file is the explicit property
     * value, falling back to this node's model file name.
     */
    private void configureInitialModel(CNKProcNNet proc, XTProps props) throws Exception {
        String modelFile = null;
        if (props.getBoolean(INIT_WITH_PREV_WEIGHTS_ATTRIBUTE_TAG, false)) {
            modelFile = props.getValue(INIT_WEIGHTS_FROM_FILE_TAG, "");
            if (modelFile == null || modelFile.length() == 0) {
                modelFile = this.getModelFileName();
            }
        }
        if (modelFile != null && modelFile.length() > 0) {
            proc.setConfigData(modelFile);
        } else {
            proc.setModel(this.getDataDictionaryAsString());
        }
    }

    /** Scores the training data with the captured model; returns true when there is nothing to output. */
    private boolean predictTrainingData(String origText) throws Exception {
        Vector outputSpecs = this.getOutputSpecs();
        PredictEngineNode.isConflictingIO(outputSpecs, this);
        if (outputSpecs.size() < 1) {
            return true;
        }
        EngineMessageHandler.sendMessageToApp("setStatusText", new Object[]{origText + ": Predicting..."});
        this.printlnVerbose("neural network: creating CNKProcNNetPredict");
        CNKProcNNetPredict proc2 = new CNKProcNNetPredict();
        try {
            this.printlnVerbose("neural network: CNKProcNNetPredict.setModel");
            proc2.setModel(this.m_modelString);
            PredictEngineNode.defineOutputsFromSpecs(proc2, outputSpecs);
            this.printlnVerbose("neural network: predicting from training data.");
            return this.getNetworkManager().executeCNKProc(this.getNodeID(), proc2);
        } finally {
            this.procDelete(proc2);
        }
    }

    /** True when the viewer property is on, the app is interactive, and the OS is not Solaris. */
    private boolean shouldShowNNetExecViewer(XTProps props) {
        return props.getBoolean(SHOW_EXEC_VIEWER_ATTRIBUTE_TAG, true)
                && MinerApp.isInteractive()
                && !MinerApp.isSolarisOS();
    }

    public XTMetaData calculateOutputMetaData(int outputNum) {
        if (outputNum == 0) {
            return PredictEngineNode.calculateOutputMetaDataFromOutputSpecs(this.getOutputSpecs());
        }
        return null;
    }

    /** Stores {@code tree} in the error-weight node cache. */
    public void createWeightErrorCacheFile(XMLTree tree) {
        this.setNodeCache(ERROR_WEIGHT_CACHE_KEY, tree);
    }

    /** Returns the error-weight node cache (may be null). */
    public Object getWeightErrorCacheFile() {
        return this.getNodeCache(ERROR_WEIGHT_CACHE_KEY);
    }

    /**
     * Progress callback from the training proc. Non-{@link CNKProcNNet} procs are ignored.
     *
     * <p>With the viewer disabled, epoch errors are accumulated and flushed to the node
     * cache on completion. With it enabled, new epochs are pushed to the viewer and
     * user stop/pause requests are honoured. After a pause, any property edits made
     * meanwhile are applied to the running proc.
     */
    public void doCallback(CNKProc proc) {
        if (!(proc instanceof CNKProcNNet)) {
            return;
        }
        CNKProcNNet nnetProc = (CNKProcNNet)proc;
        if (!this.shouldShowNNetExecViewer(this.getNodeProperties())) {
            this.recordEpochErrorsForCache(nnetProc);
            return;
        }
        int curEpoch = nnetProc.getEpoch();
        if (curEpoch > this.m_lastEpoch) {
            this.m_lastEpoch = curEpoch;
            this.sendEpochToViewer(nnetProc, curEpoch);
        }
        if (this.viewerReplyIsTrue("isUserStopRequested")) {
            this.handleUserStop(nnetProc);
            return;
        }
        if (!this.viewerReplyIsTrue("isUserPauseRequested")) {
            return;
        }
        // Snapshot properties so edits made while paused can be detected afterwards.
        XTProps propsBeforePause = null;
        try {
            propsBeforePause = new XTProps(this.getNodeProperties());
        } catch (Exception e) {
            // Best effort: without a snapshot, property edits made during the pause are skipped.
            this.printlnVerbose("neural network: could not snapshot properties before pause: " + e);
        }
        this.sendMessageToExecViewer("updateLabels", null);
        this.execPause();
        this.applyPausedEdits(nnetProc, propsBeforePause);
    }

    /** Appends this epoch's errors; on search completion writes them to the cache and resets. */
    private void recordEpochErrorsForCache(CNKProcNNet nnetProc) {
        if (this.mseTrainVector == null) {
            this.mseTrainVector = new Vector();
        }
        if (this.mseValidateVector == null) {
            this.mseValidateVector = new Vector();
        }
        if (this.epochVector == null) {
            this.epochVector = new Vector();
        }
        this.mseTrainVector.add(Double.toString(nnetProc.getEpochDev()));
        this.mseValidateVector.add(Double.toString(nnetProc.getEpochValidateDev()));
        this.epochVector.add(Integer.toString(nnetProc.getEpoch()));
        if (nnetProc.isSearchComplete()) {
            this.createCacheFile(this.epochVector, this.mseTrainVector, this.mseValidateVector, null);
            this.epochVector = null;
            this.mseTrainVector = null;
            this.mseValidateVector = null;
        }
    }

    /** Sends one epoch's statistics (and current weights, if the viewer asks) to the viewer. */
    private void sendEpochToViewer(CNKProcNNet nnetProc, int epoch) {
        Hashtable<String, Object> data = new Hashtable<String, Object>();
        data.put(NeuralNetworkExecViewer.EPOCH_TAG, Integer.toString(epoch));
        data.put(NeuralNetworkExecViewer.DEV_TAG, Double.toString(nnetProc.getEpochDev()));
        data.put(NeuralNetworkExecViewer.VALIDATEDEV_TAG, Double.toString(nnetProc.getEpochValidateDev()));
        data.put(NeuralNetworkExecViewer.BEST_WEIGHTS_EPOCH_NUM_TAG, Integer.toString(nnetProc.getEpochNumberOfBestWeights()));
        data.put(NeuralNetworkExecViewer.SEARCH_COMPLETE_TAG, Boolean.valueOf(nnetProc.isSearchComplete()));
        if (this.viewerReplyIsTrue("isWeightsRequested")) {
            data.put(NeuralNetworkExecViewer.CURRENT_WEIGHTS_TAG, nnetProc.getCurrentWeights());
        }
        this.sendMessageToExecViewer("appendData", new Object[]{data});
    }

    /**
     * Translates a viewer stop request into setError / setUserStop on the proc.
     * A "fail" request sets an error and keeps the property's final-model option.
     * Otherwise the viewer's terminate option, when it supplies one, overrides that option.
     */
    private void handleUserStop(CNKProcNNet nnetProc) {
        this.sendMessageToExecViewer("updateLabels", null);
        int option = this.getNodeProperties().getInt(FINAL_MODEL_OPTION_TAG, NNET_FINAL_MODEL_BEST);
        String weightsFile = "";
        if (this.viewerReplyIsTrue("isUserFailRequested")) {
            nnetProc.setError("User stopped network");
        } else {
            Object reply = this.sendMessageToExecViewer("getTerminateOption", null);
            if (reply instanceof Integer) {
                option = ((Integer)reply).intValue();
                if (option == USER_STOP_WEIGHTS_FROM_FILE) {
                    String file = this.viewerReplyAsString("getRetrieveWeightsFile");
                    if (file != null) {
                        weightsFile = file;
                    } else {
                        option = NNET_FINAL_MODEL_BEST;
                    }
                }
            }
        }
        nnetProc.setUserStop(option, weightsFile);
    }

    /**
     * After a pause, performs the viewer's save/retrieve weight-file requests. It then pushes
     * any changed training parameters to the proc. Parameters are compared against
     * {@code oldProps}; when that snapshot is null, the parameter and jitter updates are skipped.
     */
    private void applyPausedEdits(CNKProcNNet nnetProc, XTProps oldProps) {
        String file = this.viewerReplyAsString("getSaveWeightsFile");
        if (file != null) {
            nnetProc.getCurrentConfig(file);
        }
        file = this.viewerReplyAsString("getSaveBestWeightsFile");
        if (file != null) {
            nnetProc.getConfigWithBestWeights(file);
        }
        file = this.viewerReplyAsString("getRetrieveWeightsFile");
        if (file != null) {
            nnetProc.setConfigData(file);
        }
        if (oldProps == null) {
            return;
        }
        XTProps newProps = this.getNodeProperties();
        int epochs = newProps.getInt(EPOCHS_ATTRIBUTE_TAG, 0);
        if (epochs > 0 && epochs != oldProps.getInt(EPOCHS_ATTRIBUTE_TAG, 0)) {
            nnetProc.setMaxEpochs(epochs);
        }
        double tolerance = newProps.getDouble(ACCURACY_ATTRIBUTE_TAG, -1.0);
        if (tolerance >= 0.0 && tolerance != oldProps.getDouble(ACCURACY_ATTRIBUTE_TAG, -1.0)) {
            nnetProc.setConvergenceTol(tolerance);
        }
        String method = newProps.getValue(TRAINING_METHOD_ATTRIBUTE_TAG, "");
        // Compare by content: `!=` on Strings is reference inequality and is almost always true.
        if (method != null && method.length() > 0
                && !method.equals(oldProps.getValue(TRAINING_METHOD_ATTRIBUTE_TAG, ""))) {
            nnetProc.setTrainingMethod(method);
        }
        double rate = newProps.getDouble(LEARNING_RATE_ATTRIBUTE_TAG, 0.0);
        if (rate > 0.0 && rate <= 1.0 && rate != oldProps.getDouble(LEARNING_RATE_ATTRIBUTE_TAG, 0.0)) {
            nnetProc.setLearningRate(rate);
        }
        double momentum = newProps.getDouble(MOMENTUM_ATTRIBUTE_TAG, -1.0);
        if (momentum >= 0.0 && momentum <= 1.0 && momentum != oldProps.getDouble(MOMENTUM_ATTRIBUTE_TAG, -1.0)) {
            nnetProc.setMomentum(momentum);
        }
        double decay = newProps.getDouble(WEIGHT_DECAY_ATTRIBUTE_TAG, 0.0);
        if (decay > 0.0 && decay <= 1.0 && decay != oldProps.getDouble(WEIGHT_DECAY_ATTRIBUTE_TAG, 0.0)) {
            nnetProc.setWeightDecay(decay);
        }
        if (this.viewerReplyIsTrue("isJitterbWeightsRequested")) {
            nnetProc.setJitterWeights(JITTER_AMOUNT);
        }
    }

    /** Sends {@code message} to the viewer and reports whether it answered Boolean.TRUE. */
    private boolean viewerReplyIsTrue(String message) {
        Object reply = this.sendMessageToExecViewer(message, null);
        return reply instanceof Boolean && ((Boolean)reply).booleanValue();
    }

    /** Sends {@code message} to the viewer and returns its reply if it is a String, else null. */
    private String viewerReplyAsString(String message) {
        Object reply = this.sendMessageToExecViewer(message, null);
        return reply instanceof String ? (String)reply : null;
    }

    /**
     * Serialises per-epoch error values (and optional arrow values) as
     * {@code <XTProps>} XML and stores the result in the error-weight node cache.
     *
     * @param epochs      epoch labels (Strings)
     * @param trainingSet training error per epoch (Strings), parallel to {@code epochs}
     * @param testSet     validation error per epoch (Strings), parallel to {@code epochs}
     * @param saveArrows  optional Double values written as a space-separated list; may be null
     */
    public void createCacheFile(Vector epochs, Vector trainingSet, Vector testSet, Vector saveArrows) {
        StringBuilder xml = new StringBuilder("<XTProps>\n");
        if (saveArrows != null) {
            xml.append("<Property name=\"ArrowValues\" value=\"");
            for (int j = 0; j < saveArrows.size(); ++j) {
                xml.append(Double.toString((Double)saveArrows.elementAt(j))).append(' ');
            }
            xml.append("\"/>\n");
        }
        xml.append('\n');
        xml.append("<Property name=\"ErrorValues\">");
        for (int j = 0; j < epochs.size(); ++j) {
            xml.append("<Property name=\"").append((String)epochs.elementAt(j))
                    .append("\" value=\"").append((String)trainingSet.elementAt(j))
                    .append(' ').append((String)testSet.elementAt(j))
                    .append("\"/>\n");
        }
        xml.append("</Property>\n</XTProps>\n");
        this.createWeightErrorCacheFile(XMLTree.readFromString(xml.toString()));
    }
}

