序列化 — 读写模型

我们现在已经讲了很多,包括

  • 如何处理数据
  • 如何构建模型
  • 如何在数据上训练模型
  • 如何使用不同的损失函数来做分类和回归

但即使知道了所有这些,我们还没有完全准备好来构建一个真正的机器学习系统。这是因为我们还没有讲如何读和写模型。因为现实中,我们通常在一个地方训练好模型,然后部署到很多不同的地方。我们需要把内存中的训练好的模型存在硬盘上好下次使用。

读写NDArrays

作为开始,我们先看看如何读写NDArray。虽然我们可以使用Python的序列化包例如Pickle,不过我们更倾向直接saveload,通常这样更快,而且别的语言,例如R和Scala也能用到。

In [1]:
from mxnet import nd

x = nd.ones(3)
y = nd.zeros(4)
filename = "../data/test1.params"
nd.save(filename, [x, y])

读回来

In [2]:
a, b = nd.load(filename)
print(a, b)

[ 1.  1.  1.]
<NDArray 3 @cpu(0)>
[ 0.  0.  0.  0.]
<NDArray 4 @cpu(0)>

不仅可以读写单个NDArray,NDArray list,dict也是可以的:

In [3]:
mydict = {"x": x, "y": y}
filename = "../data/test2.params"
nd.save(filename, mydict)
In [4]:
c = nd.load(filename)
print(c)
{'x':
[ 1.  1.  1.]
<NDArray 3 @cpu(0)>, 'y':
[ 0.  0.  0.  0.]
<NDArray 4 @cpu(0)>}

读写Gluon模型的参数

跟NDArray类似,Gluon的模型(就是nn.Block)提供便利的save_paramsload_params函数来读写数据。我们同前一样创建一个简单的多层感知机

In [5]:
from mxnet.gluon import nn

def get_net():
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Dense(10, activation="relu"))
        net.add(nn.Dense(2))
    return net

net = get_net()
net.initialize()
x = nd.random.uniform(shape=(2,10))
print(net(x))

[[ 0.00202203  0.00100273]
 [-0.00134863  0.00299659]]
<NDArray 2x2 @cpu(0)>

下面我们把模型参数存起来

In [6]:
filename = "../data/mlp.params"
net.save_params(filename)

之后我们构建一个一样的多层感知机,但不像前面那样随机初始化,我们直接读取前面的模型参数。这样给定同样的输入,新的模型应该会输出同样的结果。

In [7]:
import mxnet as mx
net2 = get_net()
net2.load_params(filename, mx.cpu())  # FIXME, gluon will support default ctx later
print(net2(x))

[[ 0.00202203  0.00100273]
 [-0.00134863  0.00299659]]
<NDArray 2x2 @cpu(0)>

总结

通过load_paramssave_params可以很方便的读写模型参数。

吐槽和讨论欢迎点这里