Función de pérdida Keras con parámetro dinámico adicional

Estoy trabajando en la implementación de la repetición de experiencia priorizada para una red de q profundo, y parte de la especificación es multiplicar los gradientes por lo que se conoce como pesos de muestreo de importancia (IS). La modificación del gradiente se discute en la sección 3.4 del siguiente documento:https: //arxiv.org/pdf/1511.05952.pd Estoy luchando con la creación de una función de pérdida personalizada que tenga una variedad de pesos IS además dey_true yy_pred.

Aquí hay una versión simplificada de mi modelo:

import numpy as np
import tensorflow as tf

# Input is RAM, each byte in the range of [0, 255].
in_obs = tf.keras.layers.Input(shape=(4,))

# Normalize the observation to the range of [0, 1].
norm = tf.keras.layers.Lambda(lambda x: x / 255.0)(in_obs)

# Hidden layers.
dense1 = tf.keras.layers.Dense(128, activation="relu")(norm)
dense2 = tf.keras.layers.Dense(128, activation="relu")(dense1)
dense3 = tf.keras.layers.Dense(128, activation="relu")(dense2)
dense4 = tf.keras.layers.Dense(128, activation="relu")(dense3)

# Output prediction, which is an action to take.
out_pred = tf.keras.layers.Dense(2, activation="linear")(dense4)

opt     = tf.keras.optimizers.Adam(lr=5e-5)
network = tf.keras.models.Model(inputs=in_obs, outputs=out_pred)
network.compile(optimizer=opt, loss=huber_loss_mean_weighted)

Aquí está mi función de pérdida personalizada, que es solo una implementación de Huber Loss multiplicada por los pesos de IS:

'''
 ' Huber loss: https://en.wikipedia.org/wiki/Huber_loss
'''
def huber_loss(y_true, y_pred):
  error = y_true - y_pred
  cond  = tf.keras.backend.abs(error) < 1.0

  squared_loss = 0.5 * tf.keras.backend.square(error)
  linear_loss  = tf.keras.backend.abs(error) - 0.5

  return tf.where(cond, squared_loss, linear_loss)

'''
 ' Importance Sampling weighted huber loss.
'''
def huber_loss_mean_weighted(y_true, y_pred, is_weights):
  error = huber_loss(y_true, y_pred)

  return tf.keras.backend.mean(error * is_weights)

Lo importante es queis_weights es dinámico, es decir, es diferente cada vez quefit() se llama. Como tal, no puedo simplemente cerrar sobreis_weights como se describe aquí:Haga una función de pérdida personalizada en keras

Encontré este código en línea, que parece utilizar unaLambda capa para calcular la pérdida:https: //github.com/keras-team/keras/blob/master/examples/image_ocr.py#L47 Parece prometedor, pero me cuesta entenderlo / adaptarlo a mi problema particular. Cualquier ayuda es apreciada.

Respuestas a la pregunta(1)

Su respuesta a la pregunta