Happy II: Adding new metrics

October 4th 2021

The last time I worked on Happy's speech recognition module, I built the training scripts for the model. The next step is adding metrics with which to evaluate the performance of the neural network.

Because we are using a CTC-based model, we need to create a decoder to translate the model output into the final prediction. 

CTC addresses the issue of the input and output sequences having different lengths by merging repeated characters that are adjacent and not separated by a blank token. This works particularly well for speech recognition because some form of alignment remains between the input and output data (unlike in machine translation, for example) even though their lengths differ.
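As a toy illustration of that collapse rule, with `-` standing in for the blank (both the helper and the strings are invented for this example):

```python
def ctc_collapse(path, blank="-"):
    """Collapse a CTC path: merge adjacent repeats, drop blanks.

    Repeats are only merged when not separated by a blank, so the
    blank is what lets the model emit genuine double letters.
    """
    out = []
    previous = None
    for char in path:
        if char != blank and char != previous:
            out.append(char)
        previous = char
    return "".join(out)

print(ctc_collapse("hhee-lll-lo"))  # hello
print(ctc_collapse("heeelo"))      # helo ("ll" would need a blank between the l's)
```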


import torch


def greedy_decoder(output, labels, blank_label=28, collapse_repeated=True):
    decoded = []
    targets = [[l.item() for l in label] for label in labels]

    # output has shape (batch, time, classes): take the most likely
    # class at every timestep, then collapse the resulting path.
    for phrase in torch.argmax(output, dim=2):
        decoded.append([])
        previous = None

        for arg in phrase:
            arg = arg.item()

            if arg != blank_label:
                if not (collapse_repeated and previous == arg):
                    decoded[-1].append(arg)
                    previous = arg

            else:
                # A blank resets the previous character, so repeated
                # letters separated by a blank are both kept.
                previous = None

    decoded, labels = [int_to_text(d) for d in decoded], [
        int_to_text(l) for l in targets]

    return decoded, labels

In this function we keep each character as long as it is not a blank and not a repeat of the previous character. When we hit a blank we also reset the previous character, so that genuine double letters separated by a blank survive the collapse.
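To see the argmax-then-collapse step in isolation, here is a sketch on a fake output tensor with a toy 4-symbol alphabet (the mapping and tensor are invented for illustration; `int_to_text` and the real alphabet are left out):

```python
import torch

# Toy alphabet (my own choice, not Happy's): indices 0-2 are
# characters, index 3 is the CTC blank.
alphabet = ["a", "b", "c"]
blank_label = 3

# Fake model output of shape (batch=1, time=6, classes=4): every
# class gets a low score except the one on the intended best path.
best_path = [0, 0, 3, 0, 1, 1]  # "aa-abb", which collapses to "aab"
logits = torch.full((1, 6, 4), -5.0)
for t, cls in enumerate(best_path):
    logits[0, t, cls] = 0.0

decoded = []
for phrase in torch.argmax(logits, dim=2):
    chars, previous = [], None
    for arg in phrase:
        arg = arg.item()
        # Keep a character only if it is not the blank and not a
        # repeat; the blank still updates `previous`, which is what
        # lets real double letters through.
        if arg != blank_label and arg != previous:
            chars.append(alphabet[arg])
        previous = arg
    decoded.append("".join(chars))

print(decoded)  # ['aab']
```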

Now, to actually determine how correct this sequence is, we cannot use a simple one-to-one accuracy, since a single missing character would misalign the entire rest of the prediction even when it is otherwise correct. Instead we want to determine the word error rate and character error rate. As their names imply, the former represents how often the model gets a word wrong, while the latter is how often the model is wrong about an individual character.

Both calculations are the same, except one compares characters and the other words. The first function we need is the Levenshtein distance. Intuitively, this is the number of edits (insertions, deletions, and substitutions) that will get us from one sequence to another. The Wikipedia page on the topic is very informative.


import numpy as np


def levenshtein_distance(a, b):
    distances = np.zeros((len(a)+1, len(b)+1))

    # Transforming a prefix into the empty sequence takes as many
    # deletions (or insertions) as the prefix is long.
    for i in range(len(a)+1):
        distances[i][0] = i

    for j in range(len(b)+1):
        distances[0][j] = j

    for i in range(1, len(a)+1):
        for j in range(1, len(b)+1):
            # The matrix is offset by one relative to the sequences.
            if a[i-1] == b[j-1]:
                distances[i][j] = distances[i-1][j-1]

            else:
                distances[i][j] = min([
                    distances[i][j-1],    # insertion
                    distances[i-1][j],    # deletion
                    distances[i-1][j-1],  # substitution
                ]) + 1

    return distances[len(a)][len(b)]

What we are doing at each step is calculating the cheapest operation that gets us to the next prefix when the two characters are not the same. The final value in the lower right-hand corner of the matrix is the distance. This is then divided by the length of the reference sequence to determine the error rate.
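As a quick sanity check (repeating the function so the snippet runs on its own), the classic "kitten" to "sitting" pair needs exactly three edits: substitute k for s, substitute e for i, and insert a g:

```python
import numpy as np

def levenshtein_distance(a, b):
    distances = np.zeros((len(a) + 1, len(b) + 1))
    for i in range(len(a) + 1):
        distances[i][0] = i
    for j in range(len(b) + 1):
        distances[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                distances[i][j] = distances[i - 1][j - 1]
            else:
                distances[i][j] = min(distances[i][j - 1],
                                      distances[i - 1][j],
                                      distances[i - 1][j - 1]) + 1
    return distances[len(a)][len(b)]

print(levenshtein_distance("kitten", "sitting"))  # 3.0

# The same function works on lists of words, which is how the
# word error rate will reuse it: one substituted word here.
print(levenshtein_distance("the cat sat".split(), "the hat sat".split()))  # 1.0
```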


def error_rate(output, labels):
    rates = []

    # Error rates are conventionally normalised by the length of the
    # reference (label) sequence, not the prediction.
    for a, b in zip(output, labels):
        rates.append(levenshtein_distance(a, b) / len(b))

    return rates

Next we determine the average for the batch, which will be used to evaluate the model.


def avg(array):
    return sum(array) / len(array)


def word_error_rate(output, labels):
    output, labels = greedy_decoder(output, labels)
    return avg(error_rate([o.split(' ') for o in output], [l.split(' ') for l in labels]))


def char_error_rate(output, labels):
    return avg(error_rate(*greedy_decoder(output, labels)))
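Putting the two rates side by side on a hypothetical transcript pair makes the difference concrete (the decoder is skipped since it needs real model output, and the Levenshtein helper is repeated so the snippet is self-contained):

```python
import numpy as np

def levenshtein_distance(a, b):
    distances = np.zeros((len(a) + 1, len(b) + 1))
    for i in range(len(a) + 1):
        distances[i][0] = i
    for j in range(len(b) + 1):
        distances[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                distances[i][j] = distances[i - 1][j - 1]
            else:
                distances[i][j] = min(distances[i][j - 1],
                                      distances[i - 1][j],
                                      distances[i - 1][j - 1]) + 1
    return distances[len(a)][len(b)]

def error_rate(output, labels):
    # Normalised by the reference length.
    return [levenshtein_distance(a, b) / len(b) for a, b in zip(output, labels)]

prediction = "the cat sat on teh mat"
reference = "the cat sat on the mat"

# One word out of six is wrong...
wer = error_rate([prediction.split(" ")], [reference.split(" ")])[0]
# ...but only two character edits out of twenty-two are needed.
cer = error_rate([prediction], [reference])[0]

print(wer, cer)  # roughly 0.167 and 0.091
```

A transposed word like "teh" costs a full word error but only two character edits, which is exactly why both metrics are worth tracking.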