Atualizar valores de uma variável de matriz no fluxo tensor, indexação avançada

Eu gostaria de criar uma função que, para cada linha de um dado dado X, aplique a função softmax apenas a algumas classes amostradas, digamos 2, do total de K classes. Em python simples, o código se parece com isso:

def softy(X,W, num_samples):
    N = X.shape[0]
    K = W.shape[0]
    S = np.zeros((N,K)) 
    ar_to_sof = np.zeros(num_samples)
    sampled_ind = np.zeros(num_samples, dtype = int)
    for line in range(N):        
        for samp in range(num_samples):
            sampled_ind[samp] = randint(0,K-1)
            ar_to_sof[samp] = np.dot(X[line],np.transpose(W[sampled_ind[samp]])) 
        ar_to_sof = softmax(ar_to_sof)
        S[line][sampled_ind] = ar_to_sof 

    return S

Finalmente, S conteria zeros e valores diferentes de zero nos índices definidos para cada linha pela matriz "samped_ind". Eu gostaria de implementar isso usando o Tensorflow. O problema é que ele contém indexação "avançada" e não consigo encontrar uma maneira de usar essa biblioteca para criar isso.

Estou tentando isso usando este código:

S = tf.Variable(tf.zeros((N,K)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
wsampled = tf.placeholder(tf.float32, shape = (None,D))
ar_to_sof = tf.matmul(tfx,wsampled,transpose_b=True)
softy = tf.nn.softmax(ar_to_sof)
r = tf.random_uniform(shape=(), minval=0,maxval=K, dtype=tf.int32)
...
for line in range(N):
    sampled_ind = tf.constant(value=[sess.run(r),sess.run(r)],dtype= tf.int32)
    Wsampled = sess.run(tf.gather(W,sampled_ind))
    sess.run(softy,feed_dict={tfx:X[line:line+1], wsampled:Wsampled})

Tudo funciona até aqui, mas não consigo encontrar uma maneira de fazer a atualização que desejo na matriz S, no código python "S [linha] [sampled_ind] = ar_to_sof".

Como eu poderia fazer isso funcionar?

questionAnswers(1)

yourAnswerToTheQuestion