Red neuronal con propagación inversa no convergente

Básicamente estoy tratando de implementarbackpropogation en una red. Sé que el algoritmo de propagación hacia atrás está codificado, pero primero estoy intentando que sea funcional.

Funciona para un conjunto de entradas y salidas, pero más allá de un conjunto de capacitación, la red converge en una solución, mientras que la otra salida converge en 0.5.

Es decir, la salida para una prueba es:[0.9969527919933012, 0.003043774988797313]

[0.5000438200377985, 0.49995612243030635]

Network.java

private ArrayList<ArrayList<ArrayList<Double>>> weights;
private ArrayList<ArrayList<Double>> nodes;

private final double LEARNING_RATE = -0.25;
private final double DEFAULT_NODE_VALUE = 0.0;

private double momentum = 1.0;

public Network() {
    weights = new ArrayList<ArrayList<ArrayList<Double>>>();
    nodes = new ArrayList<ArrayList<Double>>();
}

/**
 * This method is used to add a layer with {@link n} nodes to the network.
 * @param n number of nodes for the layer
 */
public void addLayer(int n) {
    nodes.add(new ArrayList<Double>());
    for (int i = 0;i < n;i++)
        nodes.get(nodes.size()-1).add(DEFAULT_NODE_VALUE);
}

/**
 * This method generates the weights used to link layers together.
 */
public void createWeights() {
    // there are only weights between layers, so we have one less weight layer than node layer
    for (int i = 0;i < nodes.size()-1;i++) {
        weights.add(new ArrayList<ArrayList<Double>>());

        // for each node above the weight
        for (int j = 0;j < nodes.get(i).size();j++) {
            weights.get(i).add(new ArrayList<Double>());

            // for each node below the weight
            for (int k = 0;k < nodes.get(i+1).size();k++)
                weights.get(i).get(j).add(Math.random()*2-1);
        }
    }
}

/**
 * Utilizes the differentiated sigmoid function to change weights in the network
 * @param out   The desired output pattern for the network
 */
private void propogateBackward(double[] out) {
    /*
     * Error calculation using squared error formula and the sigmoid derivative
     * 
     * Output Node : dk = Ok(1-Ok)(Ok-Tk)
     * Hidden Node : dj = Oj(1-Oj)SummationkEK(dkWjk)
     * 
     * k is an output node
     * j is a hidden node
     * 
     * dw = LEARNING_RATE*d*outputOfpreviousLayer(not weighted)
     * W = dW + W
     */

    // update the last layer of weights first because it is a special case

    double dkW = 0;

    for (int i = 0;i < nodes.get(nodes.size()-1).size();i++) {

        double outputK = nodes.get(nodes.size()-1).get(i);
        double deltaK = outputK*(1-outputK)*(outputK-out[i]);

        for (int j = 0;j < nodes.get(nodes.size()-2).size();j++) {
            weights.get(1).get(j).set(i, weights.get(1).get(j).get(i) + LEARNING_RATE*deltaK*nodes.get(nodes.size()-2).get(j) );
            dkW += deltaK*weights.get(1).get(j).get(i);
        }
    }

    for (int i = 0;i < nodes.get(nodes.size()-2).size();i++) {

        //Hidden Node : dj = Oj(1-Oj)SummationkEK(dkWjk)
        double outputJ = nodes.get(1).get(i);
        double deltaJ = outputJ*(1-outputJ)*dkW*LEARNING_RATE;

        for (int j = 0;j < nodes.get(0).size();j++) {
            weights.get(0).get(j).set(i, weights.get(0).get(j).get(i) + deltaJ*nodes.get(0).get(j) );
        }


    }

}

/**
 * Propogates an array of input values through the network
 * @param in    an array of inputs
 */
private void propogateForward(double[] in) {
    // pass the weights to the input layer
    for (int i = 0;i < in.length;i++)
        nodes.get(0).set(i, in[i]);

    // propagate through the rest of the network
    // for each layer after the first layer
    for (int i = 1;i < nodes.size();i++)

        // for each node in the layer
        for (int j = 0;j < nodes.get(i).size();j++) {

            // for each node in the previous layer
            for (int k = 0;k < nodes.get(i-1).size();k++)

                // add to the node the weighted output from k to j
                nodes.get(i).set(j, nodes.get(i).get(j)+weightedNode(i-1, k, j));

            // once the node has received all of its inputs we can apply the activation function
            nodes.get(i).set(j, activation(nodes.get(i).get(j)));

        }   
}

/**
 * This method returns the activation value of an input
 * @param   in the total input of a node
 * @return  the sigmoid function at the input
 */
private double activation(double in) {
    return 1/(1+Math.pow(Math.E,-in));
}

/**
 * Weighted output for a node.
 * @param layer the layer which the transmitting node is on
 * @param node  the index of the transmitting node
 * @param previousNode  the index of the receiving node
 * @return  the output of the transmitting node times the weight between the two nodes
 */
private double weightedNode(int layer, int node, int nextNode) {
    return nodes.get(layer).get(node)*weights.get(layer).get(node).get(nextNode);
}

/**
 * This method resets all of the nodes to their default value
 */
private void resetNodes() {
    for (int i = 0;i < nodes.size();i++)
        for (int j = 0;j < nodes.get(i).size();j++)
            nodes.get(i).set(j, DEFAULT_NODE_VALUE);
}

/**
 * Teach the network correct responses for certain input values.
 * @param in    an array of input values
 * @param out   an array of desired output values
 * @param n     number of iterations to perform
 */
public void train(double[] in, double[] out, int n) {
    for (int i = 0;i < n;i++) {
        propogateForward(in);
        propogateBackward(out);
        resetNodes();
    }
}

public void getResult(double[] in) {
    propogateForward(in);
    System.out.println(nodes.get(2));
    resetNodes();
}

SnapSolve.java

public SnapSolve() {

    Network net = new Network();
    net.addLayer(2);
    net.addLayer(4);
    net.addLayer(2);
    net.createWeights();

    double[] l = {0, 1};
    double[] p = {1, 0};

    double[] n = {1, 0};
    double[] r = {0, 1};

    for(int i = 0;i < 100000;i++) {
        net.train(l, p, 1);
        net.train(n, r, 1);
    }

    net.getResult(l);
    net.getResult(n);

}

public static void main(String[] args) {
    new SnapSolve();
}

Respuestas a la pregunta(1)

Su respuesta a la pregunta