使用matplotlib对分类数据集样本可视化

在训练分类器之前,我们常常希望能够对数据集有一个直观的感受,比如像这样:
imgA
今天做CS231n作业一的时候,发现作业里有一段代码用matplotlib实现了这一功能,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):
idxs = np.flatnonzero(y_train == y)
idxs = np.random.choice(idxs, samples_per_class, replace=False)
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype('uint8'))
plt.axis('off')
if i == 0:
plt.title(cls)
plt.savefig(yourpath)

这段程序中samples_per_class为每个类别图片采样的张数,随后使用了两次枚举,外层的枚举用来从每个类里随机抽取用于展示的7张图片,内层的枚举用来将挑选出来的图片作为subplot放在matplotlib提供的图框中,最后使用plt.savefig(yourpath)保存即可。