TensorFlow: rendimiento lento al obtener gradientes en las entradas
Estoy construyendo un simple perceptrón multicapa con TensorFlow, y también necesito obtener los gradientes (o señal de error) de la pérdida en las entradas de la red neuronal.
Aquí está mi código, que funciona:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
....
for batch in batches:
...
sess.run(optimizer, feed_dict=feed_dict)
grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]
(editado para incluir el bucle de entrenamiento)
Sin la última línea (grads_wrt_input...
), esto funciona muy rápido en una máquina CUDA. Sin embargo,tf.gradients()
Reduce el rendimiento en diez veces o más.
Recuerdo que las señales de error en los nodos se calculan como valores intermedios en el algoritmo de retropropagación, y lo hice con éxito utilizando la biblioteca Java DeepLearning4j. También tenía la impresión de que esto sería una ligera modificación en el gráfico de cálculo ya construido poroptimizer
.
¿Cómo se puede hacer esto más rápido, o hay alguna otra forma de calcular los gradientes de la pérdida w.r.t. las entradas?