Turbot-DL入门教程篇-TensorFlow应用-线性回归计算
Turbot-DL入门教程篇-TensorFlow应用-线性回归计算
说明:
- 介绍如何使用tensorflow解决线性回归的问题
环境:
- Python 3.5.2
步骤:
- 创建数据集:
$ vim linear_regression_data.py
import numpy as np
import matplotlib.pyplot as plt
trX = np.linspace(-1, 1, 101)
trY = 2 * trX + \
np.ones(*trX.shape) * 4 + \
np.random.randn(*trX.shape) * 0.03
plt.figure(1)
plt.plot(trX, trY, 'o')
plt.xlabel('trX')
plt.ylabel('trY')
plt.show()
- 训练数据集线性回归模型
$ vim linear_regression.py
#!/usr/bin/env python
import tensorflow as tf
import numpy as np
trX = np.linspace(-1, 1, 101)
trY = 2 * trX + \
np.ones(*trX.shape) * 4 + \
np.random.randn(*trX.shape) * 0.03
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
w = tf.Variable(0.0, name="weights")
b = tf.Variable(0.0, name="biases")
y_model = tf.multiply(X, w) + b
cost = tf.square(Y - y_model)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(100):
for (x, y) in zip(trX, trY):
sess.run(train_op, feed_dict={X: x, Y: y})
w_ = sess.run(w)
b_ = sess.run(b)
print("Result : trY = " + str(w_) + "*trX + " + str(b_))
- 运行
python3 linear_regression.py
获取最新文章: 扫一扫右上角的二维码加入“创客智造”公众号