Calcular distancia por pares en un lote sin replicar tensor en Tensorflow?

Quiero calcular la distancia cuadrada por pares de un lote de características en Tensorflow. Tengo una implementación simple usando operaciones + y * al mosaico del tensor original:

def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))

        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])

        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)

        square_dist = tf.reduce_sum(square_diff, 1)

        return square_dist

Esta función toma como entrada dos matrices de tamaño (m, d) y (n, d) y calcula la distancia al cuadrado entre cada vector de fila. La salida es una matriz de tamaño (m, n) con el elemento 'd_ij = dist (x_i, y_j)'.

El problema es que tengo un lote grande y las características de alta intensidad 'm, n, d' que replican el tensor consumen mucha memoria. Estoy buscando otra forma de implementar esto sin aumentar el uso de memoria y solo almacenar el tensor de distancia final. Tipo de doble bucle del tensor original.

Respuestas a la pregunta(4)

Su respuesta a la pregunta