Uri Goren - 1 year ago

Accuracy score in PyTorch LSTM

I have been running this LSTM tutorial on the wikigold.conll NER dataset.

training_data contains a list of (sentence, tags) tuples, where each sentence is a list of words with one tag per word, for example:

training_data = [
    ("They also have a song called \" wake up \"".split(), ["O", "O", "O", "O", "O", "O", "I-MISC", "I-MISC", "I-MISC", "I-MISC"]),
    ("Major General John C. Scheidt Jr.".split(), ["O", "O", "I-PER", "I-PER", "I-PER", "I-PER"])
]


I wrote this function:

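import torch

# model, prepare_sequence and word_to_ix come from the tutorial code
def predict(indices):
    """Yield a tensor of predicted tag indices for each training_data sentence at the given indices."""
    for index in indices:
        inputs = prepare_sequence(training_data[index][0], word_to_ix)
        tag_scores = model(inputs)
        # tag_scores has one row per word; take the highest-scoring tag in each row
        values, target = torch.max(tag_scores, 1)
        yield target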


This way I can get the predicted labels for specific indices in the training data.
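In predict, torch.max(tag_scores, 1) reduces over the tag dimension and returns a (values, indices) pair; the indices are the predicted tag ids, one per word. A small illustration with made-up scores (the numbers are invented for the example, not model output):

```python
import torch

# Made-up log-probabilities: 3 words (rows) x 4 tags (columns)
tag_scores = torch.tensor([
    [-0.1, -2.0, -3.0, -4.0],
    [-2.5, -0.2, -3.1, -1.9],
    [-3.0, -2.2, -1.5, -0.4],
])

# Max over dim 1 (the tags): best score and its column index for each word
values, target = torch.max(tag_scores, 1)
print(target.tolist())  # [0, 1, 3]
```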

However, how do I evaluate the accuracy score across all of the training data?

Here, accuracy means the number of correctly classified words across all sentences divided by the total word count.
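For illustration, a minimal pure-Python sketch of this definition, pooling all words before dividing. word_accuracy is a hypothetical helper name, and the tag sequences are assumed to be aligned (one prediction per word):

```python
def word_accuracy(true_seqs, pred_seqs):
    """Fraction of words whose predicted tag equals the true tag,
    pooled over all sentences rather than averaged per sentence."""
    correct = 0
    total = 0
    for true_tags, pred_tags in zip(true_seqs, pred_seqs):
        if len(true_tags) != len(pred_tags):
            raise ValueError("each sentence needs exactly one predicted tag per word")
        # Count positions where the true and predicted tags agree
        correct += sum(t == p for t, p in zip(true_tags, pred_tags))
        total += len(true_tags)
    return correct / total


true_seqs = [["O", "O", "I-PER"], ["O", "I-MISC"]]
pred_seqs = [["O", "I-PER", "I-PER"], ["O", "O"]]
print(word_accuracy(true_seqs, pred_seqs))  # 3 of 5 words correct -> 0.6
```

Averaging the per-sentence accuracies instead (2/3 and 1/2) would give 7/12 ≈ 0.583, which is why the words are pooled before dividing.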

This is what I came up with, which is extremely slow and ugly:



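# predict() expects indices into training_data, not the sentences themselves
y_pred = list(predict(range(len(training_data))))
# convert the tag strings to index tensors (assumes the tutorial's tag_to_ix dict)
y_true = [prepare_sequence(t, tag_to_ix) for s, t in training_data]
c = 0
s = 0
for i in range(len(training_data)):
    n = len(y_true[i])
    # super ugly and inefficient
    s += (sum(sum(list(y_true[i].view(-1, n) == y_pred[i].view(-1, n).data))))
    c += n

print('Training accuracy: {a}'.format(a=float(s) / c))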


How can this be done efficiently in PyTorch?



P.S.: I have been trying to use sklearn's accuracy_score, without success.
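Regarding the P.S.: one likely cause (an assumption, since the error is not shown) is that scikit-learn's accuracy_score expects flat 1-D label sequences, not a list of variable-length sentences. A sketch that flattens the tag lists with itertools.chain first, using made-up toy tags:

```python
from itertools import chain

from sklearn import metrics

true_seqs = [["O", "O", "I-PER"], ["O", "I-MISC"]]
pred_seqs = [["O", "I-PER", "I-PER"], ["O", "O"]]

# Flatten the per-sentence tag lists into one label per word
y_true_flat = list(chain.from_iterable(true_seqs))
y_pred_flat = list(chain.from_iterable(pred_seqs))

# Word-level accuracy over all sentences
acc = metrics.accuracy_score(y_true_flat, y_pred_flat)
print(acc)
```

The same flattening applies to tensors of tag indices, once they are converted to plain lists or arrays.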

Answer

I would use NumPy to avoid iterating over the lists in pure Python.

The results are the same, but it runs much faster.

import numpy as np

def accuracy_score(y_true, y_pred):
    # Flatten the per-sentence sequences into one 1-D array each, then compare word by word
    y_pred = np.concatenate([np.asarray(p) for p in y_pred])
    y_true = np.concatenate([np.asarray(t) for t in y_true])
    return (y_true == y_pred).sum() / float(len(y_true))

And this is how to use it:

# from the question (predict called with indices, tags converted to index tensors):
y_pred = list(predict(range(len(training_data))))
y_true = [prepare_sequence(t, tag_to_ix) for s, t in training_data]
# NumPy accuracy score
print(accuracy_score(y_true, y_pred))
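Since the question asks for a PyTorch solution, here is a sketch that stays in PyTorch: concatenate the per-sentence index tensors with torch.cat and compare them in a single vectorized operation. torch_accuracy is a hypothetical name, and it assumes y_true and y_pred are lists of 1-D integer tensors with matching lengths (for example, the outputs of predict and of prepare_sequence applied to the tags):

```python
import torch


def torch_accuracy(y_true, y_pred):
    """Word-level accuracy from lists of 1-D tag-index tensors."""
    # Flatten all sentences into one long tensor each
    true_flat = torch.cat(y_true)
    pred_flat = torch.cat(y_pred)
    # Element-wise comparison gives a bool tensor; its mean is the accuracy
    return (true_flat == pred_flat).float().mean().item()


y_true = [torch.tensor([0, 0, 1]), torch.tensor([0, 2])]
y_pred = [torch.tensor([0, 1, 1]), torch.tensor([0, 0])]
print(torch_accuracy(y_true, y_pred))
```

This avoids the per-sentence Python loop and the round trip through NumPy; the .item() call turns the 0-d result tensor into a plain Python float.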