{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 图像分类数据集(Fashion-MNIST)\n",
"\n",
"在介绍softmax回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集MNIST [1]。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的Fashion-MNIST数据集 [2]。\n",
"\n",
"## 获取数据集\n",
"\n",
"首先导入本节需要的包或模块。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import d2lzh as d2l\n",
"from mxnet.gluon import data as gdata\n",
"import sys\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"下面,我们通过Gluon的`data`包来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数`train`来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "23"
}
},
"outputs": [],
"source": [
"mnist_train = gdata.vision.FashionMNIST(train=True)\n",
"mnist_test = gdata.vision.FashionMNIST(train=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 10000)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(mnist_train), len(mnist_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们可以通过方括号`[]`来访问任意一个样本,下面获取第一个样本的图像和标签。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "24"
}
},
"outputs": [],
"source": [
"feature, label = mnist_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"变量`feature`对应高和宽均为28像素的图像。每个像素的数值为0到255之间8位无符号整数(uint8)。它使用三维的`NDArray`存储。其中的最后一维是通道数。因为数据集中是灰度图像,所以通道数为1。为了表述简洁,我们将高和宽分别为$h$和$w$像素的图像的形状记为$h \\times w$或`(h,w)`。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((28, 28, 1), numpy.uint8)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feature.shape, feature.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"图像的标签使用NumPy的标量表示。它的类型为32位整数(int32)。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2, numpy.int32, dtype('int32'))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label, type(label), label.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "25"
}
},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh包中方便以后使用\n",
"def get_fashion_mnist_labels(labels):\n",
" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
" 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
" return [text_labels[int(i)] for i in labels]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"下面定义一个可以在一行里画出多张图像和对应标签的函数。"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh包中方便以后使用\n",
"def show_fashion_mnist(images, labels):\n",
" d2l.use_svg_display()\n",
" # 这里的_表示我们忽略(不使用)的变量\n",
" _, figs = d2l.plt.subplots(1, len(images), figsize=(12, 12))\n",
" for f, img, lbl in zip(figs, images, labels):\n",
" f.imshow(img.reshape((28, 28)).asnumpy())\n",
" f.set_title(lbl)\n",
" f.axes.get_xaxis().set_visible(False)\n",
" f.axes.get_yaxis().set_visible(False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在,我们看一下训练数据集中前9个样本的图像内容和文本标签。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "27"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"