Реализация seq2seq с поиском луча

Я сейчас реализуюseq2seq модель на основе примера кода, которыйtensorflow обеспечивает. И я хочу получитьтоп-5 декодер выходы, чтобы сделать обучение подкрепления.

Тем не менее, они реализовали модель трансляции с помощью декодера внимания, поэтому я должен реализовать лучевой поиск для получениятоп-K Результаты.

Есть часть кода, которая сейчас реализуется (этот код добавляется вtranslate.py).

Ссылка поhttps://github.com/tensorflow/tensorflow/issues/654

with tf.Graph().as_default():
  beam_size = FLAGS.beam_size # Number of hypotheses in beam
  num_symbols = FLAGS.tar_vocab_size # Output vocabulary size
  embedding_size = 10
  num_steps = 5
  embedding = tf.zeros([num_symbols, embedding_size])
  output_projection = None

  log_beam_probs, beam_symbols, beam_path = [], [], []

  def beam_search(prev, i):
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])

    probs = tf.log(tf.nn.softmax(prev))

    if i > 1:
      probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])

    best_probs, indices = tf.nn.top_k(probs, beam_size)
    indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
    best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))

    symbols = indices % num_symbols      # which word in vocabulary
    beam_parent = indices // num_symbols # which hypothesis it came from

    beam_symbols.append(symbols)
    beam_path.append(beam_parent)
    log_beam_probs.append(best_probs)

    return tf.nn.embedding_lookup(embedding, symbols)

  # Setting up graph.
  inputs = [tf.placeholder(tf.float32, shape=[None, num_symbols]) for i in range(num_steps)]

  for i in range(num_steps):
    beam_search(inputs[i], i+1)

  input_vals = tf.zeros([1, beam_size], dtype=tf.float32)

  input_feed = {inputs[i]: input_vals[i][:beam_size, :] for i in xrange(num_steps)}
  output_feed = beam_symbols + beam_path + log_beam_probs
  session = tf.InteractiveSession()
  outputs = session.run(output_feed, feed_dict=input_feed)

  print("Top_5 Sentences ")
  for predicted in enumerate(outputs[:5]):
    print(list(predicted))
    print("\n")

В части input_feed есть ошибка:

ValueError: Shape (1, 12) must have rank 1

Есть ли проблема в моем коде, чтобы сделатьЛуч-поиск?

Ответы на вопрос(1)

Ваш ответ на вопрос