Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Accuracy metric of a subsection of categories in Keras

I've got a 3-class classification problem. Let's call the classes 0, 1 and 2. In my case, class 0 is not important: whatever gets classified as class 0 is irrelevant. What is relevant is accuracy, precision, recall, and error rate for classes 1 and 2 only. I would like to define an accuracy metric that looks only at the subset of the data relating to classes 1 and 2 and reports it while the model is training.

I am not asking for code for accuracy, F1, or precision/recall; those I've found and can implement myself. What I'm asking for is code that selects a subset of the categories to compute these metrics on.

Visually, given this confusion matrix:

>     0   1   2
> 0  10   3   4
> 1   2   5   1
> 2   8   5   9

I would like to compute an accuracy measure during training for the following subset only:

>     1   2
> 1   5   1
> 2   5   9

Possible idea: concatenate the argmaxed y_pred and the argmaxed y_true, drop all instances where class 0 appears, convert what remains back into one-hot arrays, and compute a simple binary accuracy on that?
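The idea above can be sketched in plain NumPy, outside of training, to check the selection logic. This is only a minimal sketch: the function name `subset_accuracy`, the `ignore_class` parameter, and the toy batch are made up for illustration. It assumes one-hot `y_true` and per-class scores in `y_pred`, and it drops a sample when class 0 appears as either the label or the prediction, which matches the 2x2 sub-matrix shown earlier. Comparing class indices directly makes the step of converting back to one-hot unnecessary.

```python
import numpy as np

def subset_accuracy(y_true, y_pred, ignore_class=0):
    """Accuracy over samples where neither label nor prediction is ignore_class."""
    true_idx = np.argmax(y_true, axis=-1)
    pred_idx = np.argmax(y_pred, axis=-1)
    # Keep a sample only if class 0 shows up on neither side.
    keep = (true_idx != ignore_class) & (pred_idx != ignore_class)
    if not keep.any():
        return float("nan")  # nothing left to score
    return float(np.mean(true_idx[keep] == pred_idx[keep]))

# Hypothetical toy batch: six samples, three classes.
y_true = np.eye(3)[[0, 1, 2, 2, 1, 0]]
y_pred = np.array([
    [0.8, 0.1, 0.1],  # true 0, pred 0 -> dropped
    [0.1, 0.7, 0.2],  # true 1, pred 1 -> kept, correct
    [0.2, 0.5, 0.3],  # true 2, pred 1 -> kept, wrong
    [0.1, 0.2, 0.7],  # true 2, pred 2 -> kept, correct
    [0.6, 0.3, 0.1],  # true 1, pred 0 -> dropped
    [0.2, 0.2, 0.6],  # true 0, pred 2 -> dropped
])

print(subset_accuracy(y_true, y_pred))  # 2 of the 3 kept samples are correct
```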

Edit: I've tried to exclude class 0 with this code, but the result doesn't make sense: class 0 effectively gets folded into class 1 (that is, the true positives of both 0 and 1 end up labeled as 1). Still looking for help. Can anybody help out, please?

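# This solution does not work :(
from keras import backend as Ker  # assumed: "Ker" is the Keras backend alias

def my_acc(y_true, y_pred):
    # Excluding the 0-category: this slices away the class-0 column
    # but keeps every row, so class-0 samples are still counted.
    y_true_cust = y_true[:, 1:3]
    y_pred_cust = y_pred[:, 1:3]
    # Binary accuracy source code, slightly edited
    y_pred_cat = Ker.round(y_pred_cust)
    eql_cust = Ker.equal(y_true_cust, y_pred_cat)
    return Ker.mean(eql_cust, axis=-1)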
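One way to get the intended behaviour is to mask samples (rows) rather than slice class columns. The following is a hedged sketch, not a tested drop-in: the name `acc_without_class_0` is made up, it uses plain TensorFlow ops instead of the `Ker` backend alias, and it assumes one-hot `y_true` with class 0 in the first column. A sample is kept only when class 0 is neither the label nor the prediction.

```python
import tensorflow as tf

def acc_without_class_0(y_true, y_pred):
    """Accuracy over samples where neither the label nor the prediction is class 0."""
    true_idx = tf.argmax(y_true, axis=-1)
    pred_idx = tf.argmax(y_pred, axis=-1)
    # 1.0 for samples to keep, 0.0 for samples involving class 0.
    keep = tf.logical_and(tf.not_equal(true_idx, 0), tf.not_equal(pred_idx, 0))
    keep = tf.cast(keep, tf.float32)
    correct = tf.cast(tf.equal(true_idx, pred_idx), tf.float32)
    # divide_no_nan returns 0 when a batch has no kept samples.
    return tf.math.divide_no_nan(tf.reduce_sum(correct * keep), tf.reduce_sum(keep))
```

It could then be passed as `metrics=[acc_without_class_0]` to `model.compile` (the model itself is not shown in the question). As far as I know, Keras averages a function metric over batches, so the value reported per epoch would be a mean of per-batch ratios rather than the exact ratio over the whole epoch, and batches with no class 1 or 2 samples would contribute 0.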

@ Ashwin Geet D'Sa

correct_guesses_3cat = 10 + 5 + 9
print(correct_guesses_3cat)
24

total_guesses_3cat = 10+3+4+2+5+1+8+5+9
print(total_guesses_3cat)
47

accuracy_3cat = 24/47
print(f"{accuracy_3cat * 100:.1f} %")
51.1 %

correct_guesses_2cat = 5 + 9
print(correct_guesses_2cat)
14

total_guesses_2cat = 5+1+5+9
print(total_guesses_2cat)
20

accuracy_2cat = 14/20
print(f"{accuracy_2cat * 100:.1f} %")
70.0 %
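The same arithmetic can be reproduced from the confusion matrix itself, which makes it easy to check against whatever an in-training metric reports. This is a small NumPy sketch: the helper name `accuracy_from_cm` is made up, and the matrix values are the ones given in the question.

```python
import numpy as np

# Confusion matrix from the question (classes 0, 1, 2).
cm = np.array([[10, 3, 4],
               [2, 5, 1],
               [8, 5, 9]])

def accuracy_from_cm(matrix):
    """Correct guesses (diagonal) divided by all guesses."""
    return np.trace(matrix) / matrix.sum()

# Drop the row and the column belonging to class 0.
sub_cm = cm[1:, 1:]

print(f"{accuracy_from_cm(cm) * 100:.1f} %")      # 51.1 %
print(f"{accuracy_from_cm(sub_cm) * 100:.1f} %")  # 70.0 %
```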
