TensorFlow: desempenho lento ao obter gradientes nas entradas

Estou construindo um perceptron simples de várias camadas com o TensorFlow e também preciso obter os gradientes (ou sinal de erro) da perda nas entradas da rede neural.

Aqui está o meu código, que funciona:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
    ....
    for batch in batches:
        ...
        sess.run(optimizer, feed_dict=feed_dict)
        grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]

(editado para incluir o loop de treinamento)

Sem a última linha (grads_wrt_input...), isso é executado muito rápido em uma máquina CUDA. Contudo,tf.gradients() reduz bastante o desempenho em dez vezes ou mais.

Lembro que os sinais de erro nos nós são computados como valores intermediários no algoritmo de retropropagação e fiz isso com sucesso usando a biblioteca Java DeepLearning4j. Também fiquei com a impressão de que isso seria uma ligeira modificação no gráfico de computação já construído poroptimizer.

Como isso pode ser feito mais rápido, ou existe outra maneira de calcular os gradientes da perda w.r.t. as entradas?

questionAnswers(1)

yourAnswerToTheQuestion