Inyección de vectores word2vec previamente entrenados en TensorFlow seq2seq

Estaba tratando de inyectar vectores word2vec preentrenados en el modelo existente tensorflow seq2seq.

Siguiendoesta respuesta, Produje el siguiente código. Pero no parece mejorar el rendimiento como debería, aunque se actualizan los valores en la variable.

Según tengo entendido, el error podría deberse al hecho de que EmbeddingWrapper o embedding_attention_decoder crean incrustaciones independientemente del orden de vocabulario.

¿Cuál sería la mejor manera de cargar vectores preentrenados en el modelo de tensorflow?

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding"


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size):
  word2vec_model = word2vec.load(word2vec_path, encoding="latin-1")
  print("w2v model created!")
  session.run(tf.initialize_all_variables())

  assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size)
  assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size)


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size):
  vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name]
  if len(vectors_variable) != 1:
      print("Word vector variable not found or too many. key: " + embedding_key)
      print("Existing embedding trainable variables:")
      print([v.name for v in tf.trainable_variables() if "embedding" in v.name])
      sys.exit(1)

  vectors_variable = vectors_variable[0]
  vectors = vectors_variable.eval()

  with gfile.GFile(vocab_path, mode="r") as vocab_file:
      counter = 0
      while counter < vocab_size:
          vocab_w = vocab_file.readline().replace("\n", "")
          # for each word in vocabulary check if w2v vector exist and inject.
          # otherwise dont change the value.
          if word2vec_model.__contains__(vocab_w):
              w2w_word_vector = word2vec_model.get_vector(vocab_w)
              vectors[counter] = w2w_word_vector
          counter += 1

  session.run([vectors_variable.initializer],
            {vectors_variable.initializer.inputs[1]: vectors})

Respuestas a la pregunta(1)

Su respuesta a la pregunta