Probleme beim Implementieren eines XOR-Gatters mit neuronalen Netzen in Tensorflow
Ich möchte ein triviales neuronales Netzwerk erstellen, es sollte nur das XOR-Gatter implementieren. Ich verwende die TensorFlow-Bibliothek in Python. Für ein XOR-Gatter ist die vollständige Wahrheitstabelle die einzige Daten, mit der ich trainiere. Sollte das ausreichen? Eine Überoptimierung ist das, was ich sehr schnell erwarten werde. Problem mit dem Code ist, dass das Gewichte und voreingenommen nicht updaten. Irgendwie gibt es mir immer noch 100% Genauigkeit mit Null für die Vorspannungen und Gewichte.
x = tf.placeholder("float", [None, 2])
W = tf.Variable(tf.zeros([2,2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,1])
print "Done init"
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.75).minimize(cross_entropy)
print "Done loading vars"
init = tf.initialize_all_variables()
print "Done: Initializing variables"
sess = tf.Session()
sess.run(init)
print "Done: Session started"
xTrain = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
yTrain = np.array([[1], [0], [0], [0]])
acc=0.0
while acc<0.85:
for i in range(500):
sess.run(train_step, feed_dict={x: xTrain, y_: yTrain})
print b.eval(sess)
print W.eval(sess)
print "Done training"
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Result:"
acc= sess.run(accuracy, feed_dict={x: xTrain, y_: yTrain})
print acc
B0 = b.eval(sess)[0]
B1 = b.eval(sess)[1]
W00 = W.eval(sess)[0][0]
W01 = W.eval(sess)[0][1]
W10 = W.eval(sess)[1][0]
W11 = W.eval(sess)[1][1]
for A,B in product([0,1],[0,1]):
top = W00*A + W01*A + B0
bottom = W10*B + W11*B + B1
print "A:",A," B:",B
# print "Top",top," Bottom: ", bottom
print "Sum:",top+bottom
Ich verfolge das Tutorial vonhttp: //tensorflow.org/tutorials/mnist/beginners/index.md#softmax_regression und in der letzten for-Schleife drucke ich die Ergebnisse aus der Matrix (wie im Link beschrieben).
Kann jemand auf meinen Fehler hinweisen und was sollte ich tun, um ihn zu beheben?