TensorFlow: desempenho lento ao obter gradientes nas entradas
Estou construindo um perceptron simples de várias camadas com o TensorFlow e também preciso obter os gradientes (ou sinal de erro) da perda nas entradas da rede neural.
Aqui está o meu código, que funciona:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
....
for batch in batches:
...
sess.run(optimizer, feed_dict=feed_dict)
grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]
(editado para incluir o loop de treinamento)
Sem a última linha (grads_wrt_input...
), isso é executado muito rápido em uma máquina CUDA. Contudo,tf.gradients()
reduz bastante o desempenho em dez vezes ou mais.
Lembro que os sinais de erro nos nós são computados como valores intermediários no algoritmo de retropropagação e fiz isso com sucesso usando a biblioteca Java DeepLearning4j. Também fiquei com a impressão de que isso seria uma ligeira modificação no gráfico de computação já construído poroptimizer
.
Como isso pode ser feito mais rápido, ou existe outra maneira de calcular os gradientes da perda w.r.t. as entradas?