笔者,最近参加的贪心科技的机器学习训练营
学习本是一个反复的过程。
竟然要我写笔记交作业,还要写在知乎
我知乎没文章啊啊啊啊
我赶紧找下之前写的博文
从简单的一元回归分析入门机器学习
用多元线性回归分析问题
机器学习概念
线性回归实例
机器学习入门之线性回归
你所在的公司在电视上做产品广告, 收集到了电视广告投入x(以百万为单位)与产品销售量y(以亿为单位)的数据. 你作为公司的数据科学家, 希望通过分析这些数据, 了解电视广告投入x(以百万为单位)与产品销售量y的关系.
假设x与y的之间的关系是线性的, 也就是说 y = ax + b. 通过线性回归(Linear Regression), 我们就可以得知 a 和 b 的值. 于是我们在未来做规划的时候, 通过电视广告投入x, 就可以预测产品销售量y, 从而可以提前做好生产和物流, 仓储的规划. 为客户提供更好的服务.
import pandas as pdimport numpy as npimport matplotlib.pyplot as pltfrom sklearn.linear_model import LinearRegression
data = pd.read_csv("data/Advertising.csv")
data.head
data.columns
Index(['TV', 'sales'], dtype='object')
通过数据可视化分析数据
plt.figure(figsize=(16, 8))plt.scatter(data['TV'], data['sales'], c ='black')plt.xlabel("Money spent on TV ads")plt.ylabel("Sales")plt.show
训练线性回归模型
# 将pandas的Series变成numpy的ndarrayX = data['TV'].values.reshape(-1,1)y = data['sales'].values.reshape(-1,1)reg = LinearRegressionreg.fit(X, y)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
print('a = {:.5}'.format(reg.coef_[0][0]))print('b = {:.5}'.format(reg.intercept_[0]))print("线性模型为: Y = {:.5}X + {:.5} ".format(reg.coef_[0][0], reg.intercept_[0]))
a = 0.047537b = 7.0326线性模型为: Y = 0.047537X + 7.0326
可视化训练好的线性回归模型
predictions = reg.predict(X)plt.figure(figsize=(16, 8))plt.scatter(data['TV'], data['sales'], c ='black')plt.plot(data['TV'], predictions,c ='blue', linewidth=2)plt.xlabel("Money spent on TV ads")plt.ylabel("Sales")plt.show
[外链图片转存失败(img-JgGwDudv-1562917224938)(output_11_0.png)]
假设公司希望在下一个季度投一亿元的电视广告, 那么预期的产品销量会是多少呢
predictions = reg.predict([[100])print('投入一亿元的电视广告, 预计的销售量为{:.5}亿'.format( predictions[0][0]) )
投入一亿元的电视广告, 预计的销售量为11.786亿
# 练习df = pd.read_csv('exercise/height.vs.temperature.csv')
df.head
from sklearn.linear_model import LinearRegressionx = df['height'].values.reshape(-1, 1)y = df['temperature'].values.reshape(-1, 1)model= LinearRegressionmodel.fit(x,y)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
# 查看斜率和截距print(model.coef_)print(model.intercept_)
[[-0.00656953]][12.71850742]
# 查看数据plt.figure(figsize=(16, 8))plt.scatter(df['height'], df['temperature'], c ='black')plt.xlabel("heigth")plt.ylabel("temperature")plt.plot(df['height'],model.predict(df['height'].values.reshape(-1,1)))plt.show
最后当然打广告啦啦啦