"""Calculate accuracy. Authors * Jianyuan Zhong 2020 """ import torch from speechbrain.dataio.dataio import length_to_mask def Accuracy(log_probabilities, targets, length=None): """Calculates the accuracy for predicted log probabilities and targets in a batch. Arguments --------- log_probabilities : torch.Tensor Predicted log probabilities (batch_size, time, feature). targets : torch.Tensor Target (batch_size, time). length : torch.Tensor Length of target (batch_size,). Returns ------- numerator : float The number of correct samples denominator : float The total number of samples Example ------- >>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> acc = Accuracy(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> print(acc) (1.0, 2.0) """ if length is not None: mask = length_to_mask( length * targets.shape[1], max_len=targets.shape[1], ).bool() if len(targets.shape) == 3: mask = mask.unsqueeze(2).repeat(1, 1, targets.shape[2]) padded_pred = log_probabilities.argmax(-1) if length is not None: numerator = torch.sum( padded_pred.masked_select(mask) == targets.masked_select(mask) ) denominator = torch.sum(mask) else: numerator = torch.sum(padded_pred == targets) denominator = targets.shape[1] return float(numerator), float(denominator) class AccuracyStats: """Module for calculate the overall one-step-forward prediction accuracy. Example ------- >>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> stats = AccuracyStats() >>> stats.append(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> acc = stats.summarize() >>> print(acc) 0.5 """ def __init__(self): self.correct = 0 self.total = 0 def append(self, log_probabilities, targets, length=None): """This function is for updating the stats according to the prediction and target in the current batch. Arguments --------- log_probabilities : torch.Tensor Predicted log probabilities (batch_size, time, feature). targets : torch.Tensor Target (batch_size, time). length : torch.Tensor Length of target (batch_size,). """ numerator, denominator = Accuracy(log_probabilities, targets, length) self.correct += numerator self.total += denominator def summarize(self): """Computes the accuracy metric.""" return self.correct / self.total
Memory