diff --git a/modAL/multilabel.py b/modAL/multilabel.py index c908674..485f72c 100644 --- a/modAL/multilabel.py +++ b/modAL/multilabel.py @@ -223,6 +223,7 @@ def max_score(classifier: OneVsRestClassifier, X_pool: modALinput, classwise_confidence = classifier.predict_proba(X_pool) classwise_predictions = classifier.predict(X_pool) + classwise_predictions = classwise_predictions[:,None,:] classwise_scores = classwise_confidence*(classwise_predictions - 1/2) classwise_max = np.max(classwise_scores, axis=1)