tensorflow的基本用法(十)——保存神经网络参数和加载神经网络参数

文章作者:Tyan
博客:noahsnail.com  |  CSDN  |  简书

本文主要是使用tensorfl保存神经网络参数和加载神经网络参数,不包括神经网络框架。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python
# _*_ coding: utf-8 _*_

import tensorflow as tf
import numpy as np


# 保存神经网络参数
def save_para():
# 定义权重参数
W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
# 定义偏置参数
b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
# 参数初始化
init = tf.global_variables_initializer()
# 定义保存参数的saver
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(init)
# 保存session中的数据
save_path = saver.save(sess, 'my_net/save_net.ckpt')
# 输出保存路径
print 'Save to path: ', save_path

# 恢复神经网络参数
def restore_para():
# 定义权重参数
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
# 定义偏置参数
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
# 定义提取参数的saver
saver = tf.train.Saver()

with tf.Session() as sess:
# 加载文件中的参数数据,会根据name加载数据并保存到变量W和b中
save_path = saver.restore(sess, 'my_net/save_net.ckpt')
# 输出保存路径
print 'Weights: ', sess.run(W)
print 'biases: ', sess.run(b)


# save_para()
restore_para()

执行结果如下:

1
2
3
4
5
6
7
8
# save
Save to path: my_net/save_net.ckpt


# restore
Weights: [[ 1. 2. 3.]
[ 4. 5. 6.]]
biases: [[ 1. 2. 3.]]

参考资料

  1. https://www.youtube.com/user/MorvanZhou
如果有收获,可以请我喝杯咖啡!