Implementando seq2seq com pesquisa de feixe

Agora estou implementandoseq2seq modelo baseado no código de exemplo quetensorflow fornece. E eu quero ter umadecodificador top-5 resultados para fazer um aprendizado por reforço.

No entanto, eles implementaram o modelo de tradução com decodificador de atenção, portanto, eu deveria implementar a pesquisa por feixe para obtertop-k resultados.

Há uma parte do código que agora implementa (esse código é adicionado aotranslate.py)

Referência porhttps://github.com/tensorflow/tensorflow/issues/654

with tf.Graph().as_default():
  beam_size = FLAGS.beam_size # Number of hypotheses in beam
  num_symbols = FLAGS.tar_vocab_size # Output vocabulary size
  embedding_size = 10
  num_steps = 5
  embedding = tf.zeros([num_symbols, embedding_size])
  output_projection = None

  log_beam_probs, beam_symbols, beam_path = [], [], []

  def beam_search(prev, i):
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])

    probs = tf.log(tf.nn.softmax(prev))

    if i > 1:
      probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])

    best_probs, indices = tf.nn.top_k(probs, beam_size)
    indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
    best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))

    symbols = indices % num_symbols      # which word in vocabulary
    beam_parent = indices // num_symbols # which hypothesis it came from

    beam_symbols.append(symbols)
    beam_path.append(beam_parent)
    log_beam_probs.append(best_probs)

    return tf.nn.embedding_lookup(embedding, symbols)

  # Setting up graph.
  inputs = [tf.placeholder(tf.float32, shape=[None, num_symbols]) for i in range(num_steps)]

  for i in range(num_steps):
    beam_search(inputs[i], i+1)

  input_vals = tf.zeros([1, beam_size], dtype=tf.float32)

  input_feed = {inputs[i]: input_vals[i][:beam_size, :] for i in xrange(num_steps)}
  output_feed = beam_symbols + beam_path + log_beam_probs
  session = tf.InteractiveSession()
  outputs = session.run(output_feed, feed_dict=input_feed)

  print("Top_5 Sentences ")
  for predicted in enumerate(outputs[:5]):
    print(list(predicted))
    print("\n")

Na parte input_feed, há um erro:

ValueError: Shape (1, 12) must have rank 1

Existe algum problema no meu código para fazerpesquisa por feixe?

questionAnswers(1)

yourAnswerToTheQuestion