【9】例子--1--general--Plotting Cross-Validated Predictions(交叉验证预测)

看sklearn的英文文档,就算看完还是不清楚到底如何用,还不如直接看它的例子,通过例子来了解我的这个机器学习到底可以有什么用,不说了,就跟刷数学题一样,每天撸这么一个例子。

什么都不说,先放结果图

这是什么鬼??屌不屌?

代码:

from sklearn import datasets
from sklearn.model_selection import cross_val_predict
from sklearn import linear_model
import matplotlib.pyplot as plt

lr = linear_model.LinearRegression()
boston = datasets.load_boston()

y = boston.target

# cross_val_predict returns an array of the same size as `y` where each entry
# is a prediction obtained by cross validation:
predicted = cross_val_predict(lr, boston.data, y, cv=10)

fig, ax = plt.subplots()
ax.scatter(y, predicted)   #散点图
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=4)  #中间的那条直线
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()

代码详解:

1.这里面首先要搞明白,输入和输出分别是什么?

因为波士顿数据集里面包含的是波士顿房价和周边各个因素(包括城镇犯罪率、一氧化氮浓度、住宅平均房 间数等)的关系 ,其中data指的是环境因素,target指的是价格,毕竟探讨的是环境和价格的关系嘛

print boston.data.shape
print boston.data



(506, 13)
[[ 6.32000000e-03 1.80000000e+01 2.31000000e+00 ..., 1.53000000e+01
3.96900000e+02 4.98000000e+00]
[ 2.73100000e-02 0.00000000e+00 7.07000000e+00 ..., 1.78000000e+01
3.96900000e+02 9.14000000e+00]
[ 2.72900000e-02 0.00000000e+00 7.07000000e+00 ..., 1.78000000e+01
3.92830000e+02 4.03000000e+00]
...,
[ 6.07600000e-02 0.00000000e+00 1.19300000e+01 ..., 2.10000000e+01
3.96900000e+02 5.64000000e+00]
[ 1.09590000e-01 0.00000000e+00 1.19300000e+01 ..., 2.10000000e+01
3.93450000e+02 6.48000000e+00]
[ 4.74100000e-02 0.00000000e+00 1.19300000e+01 ..., 2.10000000e+01
3.96900000e+02 7.88000000e+00]]

boston.data是一个包含了506个样本,每个样本有13个feature

print boston.target

[ 24. 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 15. 18.9
21.7 20.4 18.2 19.9 23.1 17.5 20.2 18.2 13.6 19.6 15.2 14.5
15.6 13.9 16.6 14.8 18.4 21. 12.7 14.5 13.2 13.1 13.5 18.9
20. 21. 24.7 30.8 34.9 26.6 25.3 24.7 21.2 19.3 20. 16.6
14.4 19.4 19.7 20.5 25. 23.4 18.9 35.4 24.7 31.6 23.3 19.6
18.7 16. 22.2 25. 33. 23.5 19.4 22. 17.4 20.9 24.2 21.7
22.8 23.4 24.1 21.4 20. 20.8 21.2 20.3 28. 23.9 24.8 22.9
23.9 26.6 22.5 22.2 23.6 28.7 22.6 22. 22.9 25. 20.6 28.4
21.4 38.7 43.8 33.2 27.5 26.5 18.6 19.3 20.1 19.5 19.5 20.4
19.8 19.4 21.7 22.8 18.8 18.7 18.5 18.3 21.2 19.2 20.4 19.3
22. 20.3 20.5 17.3 18.8 21.4 15.7 16.2 18. 14.3 19.2 19.6
23. 18.4 15.6 18.1 17.4 17.1 13.3 17.8 14. 14.4 13.4 15.6

2.函数是几个鬼?

LinearRegression 怎么算?

cv 代表什么?

cross_val_predict 得出预测值?怎么实现的?

参考资料:

http://scikit-learn.org/stable/auto_examples/plot_cv_predict.html#sphx-glr-auto-examples-plot-cv-predict-py

个人公众号,比较懒,很少更新,可以在上面提问题:

更多精彩,请移步公众号阅读:

Sam avatar
About Sam
专注生物信息 专注转化医学