Keras Custom Metric para precisão de classe única

Estou criando uma métrica personalizada para medir a precisão de uma classe no meu conjunto de dados de várias classes durante o treinamento. Estou tendo problemas para selecionar a turma.

Os destinos são um ponto quente (por exemplo: o rótulo da classe 0 é [1 0 0 0 0]:

from keras import backend as K

def single_class_accuracy(y_true, y_pred):
    idx = bool(y_true[:, 0])              # boolean mask for class 0 
    class_preds = y_pred[idx]
    class_true = y_true[idx]
    class_acc = K.mean(K.equal(K.argmax(class_true, axis=-1), K.argmax(class_preds, axis=-1)))  # multi-class accuracy  
    return class_acc

O problema é que temos que usar as funções Keras para indexar tensores. Como você cria uma máscara booleana para um tensor? Obrigado.

questionAnswers(1)

yourAnswerToTheQuestion