Computes a multi-label confusion matrix.
```r
metric_multilabel_confusion_matrix(
  num_classes,
  name = "Multilabel_confusion_matrix",
  dtype = tf$int32
)
```
Argument | Description |
---|---|
num_classes | Number of unique classes in the dataset. |
name | (Optional) String name of the metric instance. |
dtype | (Optional) Data type of the metric result. Defaults to `tf$int32`. |
Value: `MultiLabelConfusionMatrix`: float
Computes a class-wise confusion matrix for evaluating classification. If multi-class input is provided, it is treated as multilabel data.

The resulting matrix `M` has shape `(num_classes, 2, 2)`. For example, a classification problem with two classes (i.e. `num_classes = 2`) yields a matrix of shape `(2, 2, 2)`. Each class `i` has a dedicated 2 x 2 matrix that contains:

- true negatives for class `i` in `M(0,0)`
- false positives for class `i` in `M(0,1)`
- false negatives for class `i` in `M(1,0)`
- true positives for class `i` in `M(1,1)`

```r
library(tensorflow)
library(tfaddons)

# multilabel confusion matrix
y_true <- tf$constant(list(as.integer(c(1, 0, 1)), as.integer(c(0, 1, 0))), dtype = tf$int32)
y_pred <- tf$constant(list(as.integer(c(1, 0, 0)), as.integer(c(0, 1, 1))), dtype = tf$int32)
output <- metric_multilabel_confusion_matrix(num_classes = 3)
output$update_state(y_true, y_pred)
print(output$result())
# [[[1 0] [0 1]] [[1 0] [0 1]] [[0 1] [1 0]]]

# if multiclass input is provided
y_true <- tf$constant(list(as.integer(c(1, 0, 0)), as.integer(c(0, 1, 0))), dtype = tf$int32)
y_pred <- tf$constant(list(as.integer(c(1, 0, 0)), as.integer(c(0, 0, 1))), dtype = tf$int32)
output <- metric_multilabel_confusion_matrix(num_classes = 3)
output$update_state(y_true, y_pred)
print(output$result())
# [[[1 0] [0 1]] [[1 0] [1 0]] [[1 1] [0 0]]]
```
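To make the per-class layout concrete, here is a minimal base-R sketch that counts true negatives, false positives, false negatives, and true positives for each column of 0/1 label matrices. `multilabel_cm()` is a hypothetical helper, not part of tfaddons; it assumes one column per class and a single batch, does not accumulate state across `update_state()` calls, and is not the metric's actual implementation. Because R indexing is 1-based, `M(0,0)` in the description corresponds to `[i, 1, 1]` here.

```r
# Hypothetical helper (not a tfaddons function): builds per-class
# 2 x 2 matrices laid out as [TN, FP; FN, TP] from 0/1 label matrices.
multilabel_cm <- function(y_true, y_pred) {
  stopifnot(identical(dim(y_true), dim(y_pred)))
  num_classes <- ncol(y_true)
  # result indexed as [class, actual (1 = neg, 2 = pos), predicted (1 = neg, 2 = pos)]
  m <- array(0L, dim = c(num_classes, 2, 2))
  for (i in seq_len(num_classes)) {
    actual <- y_true[, i] == 1
    predicted <- y_pred[, i] == 1
    m[i, 1, 1] <- sum(!actual & !predicted)  # true negatives
    m[i, 1, 2] <- sum(!actual & predicted)   # false positives
    m[i, 2, 1] <- sum(actual & !predicted)   # false negatives
    m[i, 2, 2] <- sum(actual & predicted)    # true positives
  }
  m
}

# Same inputs as the multilabel example above, as plain R matrices
y_true <- rbind(c(1, 0, 1), c(0, 1, 0))
y_pred <- rbind(c(1, 0, 0), c(0, 1, 1))
cm <- multilabel_cm(y_true, y_pred)
print(cm[3, , ])  # third class: one false positive, one false negative
```

The per-class slices should agree with the tensors printed by the tfaddons metric for the same inputs, which makes this a handy way to sanity-check how `M(0,1)` and `M(1,0)` are assigned.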