[D] Keras – Calculating run time precision (and recall) in a multi-label problem?
I’ve come across a bit of a problem, and my attempts at coding a solution have been unsuccessful. I could do this quite easily in numpy, but since the calculations have to be done on `keras.backend` tensor objects, I can’t figure it out.
Sidenote – how do you guys debug code in these backend functions? You can’t print() or step through…
For my models, I’m mostly interested in precision and recall, because those are what directly impact a real world application of the model. I typically calculate them as such:
```python
from keras import backend as K

def pos_precision_acc(y_true, y_pred):
    interesting_class_id = 1
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Replace class_id_preds with class_id_true for recall here
    accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc
```
So, if I have a one-hot encoded binary problem with labels as [1 0] for negative and [0 1] for positive, the above code would return the precision for the positive class.
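To make the intended computation concrete, here’s the numpy equivalent of the metric above on some made-up batch data (the labels and predictions are purely illustrative):

```python
import numpy as np

# Hypothetical batch: one-hot labels, [1 0] = negative, [0 1] = positive
y_true = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])
y_pred = np.array([[0.2, 0.8], [0.3, 0.7], [0.6, 0.4], [0.9, 0.1]])

interesting_class_id = 1
class_id_true = np.argmax(y_true, axis=-1)   # [1, 0, 1, 0]
class_id_preds = np.argmax(y_pred, axis=-1)  # [1, 1, 0, 0]

# Precision: of all instances predicted positive, how many really are positive
accuracy_mask = (class_id_preds == interesting_class_id).astype(int)          # [1, 1, 0, 0]
correct = (class_id_true == class_id_preds).astype(int) * accuracy_mask       # [1, 0, 0, 0]
precision = correct.sum() / max(accuracy_mask.sum(), 1)                       # 1 / 2 = 0.5
```

One true positive out of two positive predictions, so the metric comes out to 0.5 here.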
However, in a multi-label problem, my classes are multi-hot-encoded – for example a GT label [1 1 0 0] would indicate that the first two are positive, and the last two are negative. The pos_precision_acc() func above would fail at this, because the argmax call limits the output to a single result per instance.
I tried modifying the above function in several ways, but I’m not having any luck. argmax() is out for sure, because it only returns the index of a single max value per instance, whereas I expect multiple positives. I also tried calling .equal() after .flatten() on the tensors, but that didn’t help either.
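For clarity, this is the multi-label computation I’m after, written in plain numpy (the 0.5 threshold in place of argmax is my own assumption, as is the example data):

```python
import numpy as np

# Hypothetical multi-hot batch: [1 1 0 0] means classes 0 and 1 are positive
y_true = np.array([[1, 1, 0, 0],
                   [0, 1, 1, 0]])
y_pred = np.array([[0.9, 0.4, 0.8, 0.1],
                   [0.2, 0.7, 0.6, 0.3]])

# Threshold instead of argmax, since several classes can be "on" at once
preds = (y_pred >= 0.5).astype(int)  # [[1, 0, 1, 0], [0, 1, 1, 0]]

true_positives = (preds * y_true).sum()   # 3
predicted_positives = preds.sum()         # 4
precision = true_positives / max(predicted_positives, 1)  # 3 / 4 = 0.75
# recall would divide by y_true.sum() instead of preds.sum()
```

This is the version I can’t seem to translate into K.* tensor operations.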
Anyone familiar with keras.backend tensor operations willing to help out? It should actually be pretty easy… I’m just not having much luck.
Note – I could just calculate the precision and recall on the final model (in fact, I do), but I want to use these outputs as a monitored measure in a parameter search (for performance as well as early stopping), so having it work as a typical custom accuracy metric would be ideal.