Calcular a distância aos pares em um lote sem replicar o tensor no fluxo de tensão?

Desejo calcular a distância quadrada aos pares de um lote de recursos no Tensorflow. Eu tenho uma implementação simples usando operações + e * lado a lado com o tensor original:

def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))

        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])

        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)

        square_dist = tf.reduce_sum(square_diff, 1)

        return square_dist

Essa função usa como entrada duas matrizes de tamanho (m, d) e (n, d) e calcula a distância ao quadrado entre cada vetor de linha. A saída é uma matriz de tamanho (m, n) com o elemento 'd_ij = dist (x_i, y_j)'.

O problema é que eu tenho um lote grande e recursos de baixa luminosidade 'm, n, d' replicando o tensor consomem muita memória. Estou procurando outra maneira de implementar isso sem aumentar o uso de memória e apenas armazenar o tensor de distância final. Tipo de loop duplo do tensor original.

questionAnswers(4)

yourAnswerToTheQuestion