TensorFlow: rendimiento lento al obtener gradientes en las entradas

Estoy construyendo un simple perceptrón multicapa con TensorFlow, y también necesito obtener los gradientes (o señal de error) de la pérdida en las entradas de la red neuronal.

Aquí está mi código, que funciona:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
    ....
    for batch in batches:
        ...
        sess.run(optimizer, feed_dict=feed_dict)
        grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]

(editado para incluir el bucle de entrenamiento)

Sin la última línea (grads_wrt_input...), esto funciona muy rápido en una máquina CUDA. Sin embargo,tf.gradients() Reduce el rendimiento en diez veces o más.

Recuerdo que las señales de error en los nodos se calculan como valores intermedios en el algoritmo de retropropagación, y lo hice con éxito utilizando la biblioteca Java DeepLearning4j. También tenía la impresión de que esto sería una ligera modificación en el gráfico de cálculo ya construido poroptimizer.

¿Cómo se puede hacer esto más rápido, o hay alguna otra forma de calcular los gradientes de la pérdida w.r.t. las entradas?

Respuestas a la pregunta(1)

Su respuesta a la pregunta