¿Cómo usar el Iterador de API de conjunto de datos de tensorflow como entrada de una red neuronal (recurrente)?

Cuando utilizo el Iterador de API de conjunto de datos de tensorflow, mi objetivo es definir un RNN que funcione en el iteradorget_next() tensores como su entrada (ver(1) en el código)

Sin embargo, simplemente definiendo eldynamic_rnn conget_next() como su entrada da como resultado un error:ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

Ahora sé que una solución es simplemente crear un marcador de posición paranext_batch y entonceseval() el tensor (porque no puede pasar el tensor en sí) y pasarlo usandofeed_dict (verX y(2) en el código) Sin embargo, si lo entiendo correctamente, esta no es una solución eficiente, ya que primero evaluamos y luego reinicializamos el tensor.

¿Hay alguna manera de:

Definir eldynamic_rnn directamente encima de la salida del iterador;

o:

De alguna manera pasa directamente el existenteget_next() tensor al marcador de posición que es la entrada dedynamic_rnn?

Ejemplo de trabajo completo; el(1) versión es lo que me gustaría trabajar pero no lo hace, mientras(2) es la solución que funciona.

import tensorflow as tf

from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator

data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
                                   dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2):
X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

# (2):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)

    # (1):
    # o, s = sess.run([outputs, states])
    # o, s = sess.run([outputs, states])

    # (2):
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})

(Usando tensorflow 1.4.0, Python 3.6.)

Muchas gracias :)

Respuestas a la pregunta(1)

Su respuesta a la pregunta