Inicializar marcador de posición de keras como entrada a una capa personalizada
Quiero manipular las activaciones de la capa anterior con una capa de keras personalizada. La capa de abajo simplemente multiplica un número con las activaciones de la capa anterior.
class myLayer(Layer):
def __init__(self, **kwargs):
super(myLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.output_dim = input_shape[0][1]
super(myLayer, self).build(input_shape)
def call(self, inputs, **kwargs):
if not isinstance(inputs, list):
raise ValueError('This layer should be called on a list of inputs.')
mainInput = inputs[0]
nInput = inputs[1]
changed = tf.multiply(mainInput,nInput)
forTest = changed
forTrain = inputs[0]
return K.in_train_phase(forTrain, forTest)
def compute_output_shape(self, input_shape):
print(input_shape)
return (input_shape[0][0], self.output_dim)
Estoy creando el modelo como
inputTensor = Input((5,))
out = Dense(units, input_shape=(5,),activation='relu')(inputTensor)
n = K.placeholder(shape=(1,))
auxInput = Input(tensor=n)
out = myLayer()([out, auxInput])
out = Dense(units, activation='relu')(out)
out = Dense(3, activation='softmax')(out)
model = Model(inputs=[inputTensor, auxInput], outputs=out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics='acc'])
Me sale este error cuando intento usar
model.fit(X_train, Y_train, epochs=epochs, verbose=1)
Error
InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_3' with dtype float and shape [1]
Y cuando trato de dar el valor al marcador de posición como
model.fit([X_train, np.array([3])], Y_train, epochs=epochs, verbose=1)
Yo obtengo:
ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 arrays but instead got the following list of 2 arrays:
¿Cómo debo inicializar este marcador de posición? Mi objetivo es usar model.evaluate para probar el efecto de diferentes valores de n el modelo durante la inferencia. Gracias.