TensorFlow: recuerde el estado LSTM para el próximo lote (LSTM con estado)

Dado un modelo LSTM entrenado, quiero realizar inferencias para pasos de tiempo únicos, es decir,seq_length = 1 en el ejemplo a continuación. Después de cada paso de tiempo, los estados LSTM internos (memoria y ocultos) deben recordarse para el próximo 'lote'. Para el comienzo de la inferencia, los estados internos de LSTMinit_c, init_h se calculan dada la entrada. Estos se almacenan en unLSTMStateTuple objeto que se pasa al LSTM. Durante el entrenamiento, este estado se actualiza cada paso de tiempo. Sin embargo, por inferencia, quiero elstate para guardarse entre lotes, es decir, los estados iniciales solo deben calcularse desde el principio y luego los estados LSTM deben guardarse después de cada 'lote' (n = 1).

Encontré esta pregunta relacionada con StackOverflow:Tensorflow, ¿la mejor manera de guardar el estado en RNN?. Sin embargo, esto solo funciona sistate_is_tuple=False, pero TensorFlow pronto dejará de utilizar este comportamiento (consulternn_cell.py) Keras parece tener un buen envoltorio para hacercon estado LSTM posibles pero no sé la mejor manera de lograr esto en TensorFlow. Este problema en el TensorFlow GitHub también está relacionado con mi pregunta:https://github.com/tensorflow/tensorflow/issues/2838

¿Alguien buenas sugerencias para construir un modelo LSTM con estado?

inputs  = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")

num_lstm_layers = 2

with tf.variable_scope("LSTM") as scope:

    lstm_cell  = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
    self.lstm  = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_tuple=True)

    init_c = # compute initial LSTM memory state using contents in placeholder 'inputs'
    init_h = # compute initial LSTM hidden state using contents in placeholder 'inputs'
    self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers

    outputs = []

    for step in range(seq_length):

        if step != 0:
            scope.reuse_variables()

        # CNN features, as input for LSTM
        x_t = # ... 

        # LSTM step through time
        output, self.state = self.lstm(x_t, self.state)
        outputs.append(output)

Respuestas a la pregunta(2)

Su respuesta a la pregunta