Tensorflow while_loop para treinamento
No meu problema, preciso executar o GD com 1 exemplo de dados em cada etapa do treinamento. É um problema conhecido que o session.run () tem sobrecarga e, portanto, é muito longo para treinar o modelo. Na tentativa de evitar sobrecarga, tentei usar while_loop e treinar o modelo em todos os dados com uma chamada run (). Mas a abordagem não funciona e o train_op não executa nem mesmo os. Abaixo um exemplo simples do que estou fazendo:
data = [k*1. for k in range(10)]
tf.reset_default_graph()
i = tf.Variable(0, name='loop_i')
q_x = tf.FIFOQueue(100000, tf.float32)
q_y = tf.FIFOQueue(100000, tf.float32)
x = q_x.dequeue()
y = q_y.dequeue()
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = (tf.add(tf.mul(x, w), b) - y)**2
gs = tf.Variable(0)
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss, global_step=gs)
s = tf.Session()
s.run(tf.initialize_all_variables())
def cond(i):
return i < 10
def body(i):
return tf.tuple([tf.add(i, 1)], control_inputs=[train_op])
loop = tf.while_loop(cond, body, [i])
for _ in range(1):
s.run(q_x.enqueue_many((data, )))
s.run(q_y.enqueue_many((data, )))
s.run(loop)
s.close()
O que estou fazendo de errado? Ou existe outra solução desse problema com despesas gerais muito caras?
Obrigado!