Elegante Methode zum Auswählen eines Elements pro Zeile in Tensorflow

Gegeben..

a MatrixA von Form[m, n] ein TensorI von Form[m]

Ich möchte eine Liste bekommenJ von Elementen ausA woJ[i] = A[i, I[i]].

Das ist,I enthält den Index des Elements, das aus jeder Zeile in @ ausgewählt werden soA.

Kontext: Ich habe bereits dasargmax(A, 1) und jetzt will ich auch dasmax. Ich weiß, dass ich nur @ verwenden kareduce_max. Und nachdem ich ein bisschen rumprobiert hatte, kam ich auch auf folgendes:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))

Bei dem dieto_int64 wird benötigt, da range nur @ erzeuint32 undargmax produziert nurint64.

Keiner der beiden kommt mir besonders elegant vor. Man hat Laufzeit-Overhead (wahrscheinlich über Faktorn) und der andere hat einen unbekannten kognitiven Overheadfaktor. Vermisse ich hier etwas?

Antworten auf die Frage(4)

Ihre Antwort auf die Frage