使用Python程序识别图像中的数字

最近跑了很多的视频分类的实验,实验日志我记录在google sheet中,每个实验在本地都保存有对应的log和混淆矩阵。
在运行实验的时候,我为了便于观察混淆矩阵,将它保存成了png格式的图像,现在为了计算一个新的指标,我需要用到混淆矩阵中的数据。
最直接的办法就是手动输入,由于数据量较大,而且小数点后位数也很多,这种做法显然有违人道主义精神。
为了解决这个问题,我找到了一种方法,能够识别出图像中的数字。
我的混淆矩阵全都是用统一的格式保存的,因此本文不涉及检测的问题,直接定位就行了。

识别数字

说到识别数字,相信很多人和我的第一反应一样,想到了Mnist数据集。但Mnist是手写数字,而且是单个数字的识别,还需要自己训练模型,或者去调试别人训练好的模型。
我不想搞这么麻烦,于是找到了pytesseract包,pytesseract是基于Google tesseract
ocr
的一个OCR工具,识别率还不错。

安装pytesseract

以Ubuntu为例

1
2
sudo apt-get install tesseract-ocr libtesseract-dev
pip install pytesseract

对混淆矩阵进行识别

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def getCFMatrix(imgpath, class_num=4, start_position=(40, 128), block_height=95, block_width=93):
'''
Convert confusion matrix to numpy matrix via ocr.
:param imgpath: Confusion matrix file path.
:return: Numpy matrix.
'''
img = cv2.imread(imgpath)
ErrorRows = []
CFMatrix = np.zeros([class_num, class_num])
for i in range(class_num):
for j in range(class_num):
cropped_img = img[start_position[0] + block_height * i:block_height * (i + 1) + start_position[0],
start_position[1] + block_width * j:start_position[1] + block_width * (j + 1)]
text = pytesseract.image_to_string(cropped_img)
CFMatrix[i, j] = float(text)
RowSum = CFMatrix.sum(axis=1)
for i in range(RowSum.shape[0]):
if np.abs(RowSum[i] - 1.0) > 0.0005:
ErrorRows.append(i)
print('-'*20)
print("Error occurs in row {0}".format(i+1))
print(imgpath)
return CFMatrix, ErrorRows

在这段代码中,我手动划分出了混淆矩阵每个数字所在的区域。在查看识别结果的时候,我发现tesseract常常会存在识别错误的情况,比如将5和6,3和8弄混。好在混淆矩阵天生的性质是Ground Truth对应的行/列元素之和应该为1。根据这一性质我们可以检查是否存在识别错误的情况。

将结果保存至Excel

识别完成之后还需要手动校正一些错误,根据上一节,错误位置已经被标记出来,因此校正工作就十分轻松了。
我通过xlsxwriter保存识别后的矩阵以及原始的混淆矩阵图像。
详细的代码在这个Gist页面