pytorch/ignite

Possible improvements for Accuracy

Open

#1089 opened on May 31, 2020

10 comments, 4 reactions, 1 assignee
Labels: Hacktoberfest, PyDataGlobal, enhancement, help wanted


Description

The feature request is described in full detail here; below is a quick recap.

I run into two inconveniences with the current interface of Accuracy.

1. Inconsistent input format for binary classification and multiclass problems

For binary classification, Accuracy expects labels (0s and 1s) as input, whereas for multiclass problems it expects probabilities/logits. Even as a somewhat experienced Ignite user, I still get confused by this behavior.
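To make the inconsistency concrete, here is a minimal plain-Python sketch of the two input conventions. It is a simplified illustration, not Ignite's implementation. The helper names binary_accuracy and multiclass_accuracy are hypothetical, and plain lists stand in for tensors.

```python
# Simplified illustration of the two input conventions (not Ignite's code).
# Hypothetical helpers; plain lists stand in for tensors.

def binary_accuracy(y_pred, y):
    """Binary case: y_pred must already be hard labels in {0, 1}."""
    if any(p not in (0, 1) for p in y_pred):
        raise ValueError("binary y_pred must contain labels 0/1, not scores")
    correct = sum(int(p == t) for p, t in zip(y_pred, y))
    return correct / len(y)


def multiclass_accuracy(y_pred, y):
    """Multiclass case: y_pred holds one score (probability or logit) per class;
    the predicted label is the index of the largest score (argmax)."""
    labels = [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    correct = sum(int(p == t) for p, t in zip(labels, y))
    return correct / len(y)


y_true = [0, 1, 1, 0]
# Binary: the caller must convert scores to labels before calling.
print(binary_accuracy([0, 1, 0, 0], y_true))  # 0.75
# Multiclass: raw per-class scores are accepted directly.
scores = [[2.0, -1.0], [0.1, 0.9], [-0.3, 1.2], [0.4, 0.6]]
print(multiclass_accuracy(scores, y_true))  # 0.75
```

The same kind of model output (scores) is accepted in one case and rejected in the other, which is the source of the confusion described above.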

2. No shortcuts for saying "I want to pass logits/probabilities as input"

In practice, I have never used Accuracy in the following manner for binary classification:

accuracy = Accuracy()

Instead, I always do one of the following:

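import torch
from ignite.metrics import Accuracy
# model outputs logits: apply sigmoid, then round to {0, 1}
accuracy = Accuracy(output_transform=lambda out: (torch.round(torch.sigmoid(out[0])), out[1]))
# or, if the model already outputs probabilities: just round
accuracy = Accuracy(output_transform=lambda out: (torch.round(out[0]), out[1]))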
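Both transforms amount to fixed thresholds. Rounding a probability labels it 1 only above 0.5. Rounding the sigmoid of a logit labels it 1 exactly when the logit is positive.

The plain-Python check below illustrates this. The standard library's math.exp and round stand in for the torch calls. Python's round resolves exact ties to the even value, so sigmoid(0) = 0.5 becomes 0. torch.round documents the same half-to-even behavior.

```python
import math


def sigmoid(x):
    """Logistic function: maps a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


logits = [-3.0, -0.2, 0.0, 0.2, 3.0]

# Transform 1: logits -> probabilities -> hard labels via rounding
labels_from_logits = [round(sigmoid(x)) for x in logits]

# Equivalent rule: a logit maps to label 1 exactly when it is > 0
# (sigmoid(0) == 0.5, and round() sends the tie 0.5 to the even value 0)
threshold_rule = [int(x > 0) for x in logits]

print(labels_from_logits)  # [0, 0, 0, 1, 1]
print(labels_from_logits == threshold_rule)  # True
```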

Suggested solution for both problems: let the user state explicitly in which form the input will be passed:

import enum
class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...
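One way the proposed mode argument could be used internally is to normalize every binary input to hard labels before counting matches. The sketch below is hypothetical. Mode mirrors the enum from the proposal, while to_labels and binary_accuracy are illustrative helpers, not part of Ignite. Plain lists stand in for tensors.

```python
import enum


class Mode(enum.Enum):
    """Hypothetical stand-in for the proposed Accuracy.Mode."""
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()


def to_labels(y_pred, mode=Mode.LABELS):
    """Convert binary predictions to hard 0/1 labels according to the mode."""
    if mode is Mode.LABELS:
        return list(y_pred)  # already labels: pass through unchanged
    if mode is Mode.PROBABILITIES:
        return [int(p > 0.5) for p in y_pred]  # threshold probabilities at 0.5
    if mode is Mode.LOGITS:
        return [int(z > 0) for z in y_pred]  # sigmoid(z) > 0.5 exactly when z > 0
    raise ValueError(f"unknown mode: {mode}")


def binary_accuracy(y_pred, y, mode=Mode.LABELS):
    """Fraction of positions where the normalized prediction equals the target."""
    labels = to_labels(y_pred, mode)
    return sum(int(p == t) for p, t in zip(labels, y)) / len(y)


y = [1, 0, 1, 1]
print(binary_accuracy([1, 0, 0, 1], y))                              # 0.75
print(binary_accuracy([0.9, 0.2, 0.4, 0.8], y, Mode.PROBABILITIES))  # 0.75
print(binary_accuracy([2.2, -1.4, -0.4, 1.4], y, Mode.LOGITS))       # 0.75
```

With an explicit mode, the same metric accepts labels, probabilities or logits, and the user no longer writes a transform lambda.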

The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.
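The issue does not spell out the __call__ extension, so the following is only one possible reading. Each Mode member, when called with a threshold, returns a converter from raw predictions to 0/1 labels.

The threshold semantics are an assumption made for this sketch. The threshold is always a probability, and for LOGITS it is applied after the sigmoid. The Mode class here is hypothetical, not Ignite's API.

```python
import enum
import math


class Mode(enum.Enum):
    """Hypothetical Mode whose members can be called with a custom threshold."""
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()

    def __call__(self, threshold=0.5):
        """Return a function mapping raw binary predictions to 0/1 labels.

        Assumption: threshold is a probability; for LOGITS it is applied
        after converting each logit with the sigmoid.
        """
        if self is Mode.LABELS:
            return lambda y_pred: list(y_pred)  # threshold is irrelevant for labels
        if self is Mode.PROBABILITIES:
            return lambda y_pred: [int(p > threshold) for p in y_pred]
        # Mode.LOGITS: convert each logit to a probability first
        return lambda y_pred: [int(1.0 / (1.0 + math.exp(-z)) > threshold) for z in y_pred]


probs = [0.55, 0.65, 0.75, 0.95]
print(Mode.PROBABILITIES()(probs))        # [1, 1, 1, 1]  (default threshold 0.5)
print(Mode.PROBABILITIES(0.7)(probs))     # [0, 0, 1, 1]
print(Mode.LOGITS(0.7)([0.0, 1.0, 2.0]))  # [0, 1, 1]
```

Because __call__ is defined on the members rather than the enum class, lookups such as Mode(2) keep working as usual. A custom cutoff then needs no extra constructor argument on Accuracy.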
