
Greedy Machine Learning Bootcamp (贪心机器学习训练营)

Source: Greedy Academy (贪心学院), https://www.zhihu.com/people/tan-xin-xue-yuan/activities

KNN regression assignment

Case study: estimating used-car prices


import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Read the data
df = pd.read_csv('data.csv')
df  # display the DataFrame

Data preprocessing

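Color is a text column, so it is one-hot encoded to turn the text into numbers; the numeric features are standardized later.

One-hot encoding is done with get_dummies.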

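# Clean the data
# One-hot encode the color
df_colors = df['Color'].str.get_dummies().add_prefix('Color: ')
# One-hot encode the type (converted to str first so that the .str accessor works)
df_type = df['Type'].apply(str).str.get_dummies().add_prefix('Type: ')
# Append the one-hot columns
df = pd.concat([df, df_colors, df_type], axis=1)
# Drop the original columns (Brand is dropped without being encoded)
df = df.drop(['Brand', 'Type', 'Color'], axis=1)
df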
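To show what the one-hot step produces, here is a minimal sketch on a hypothetical three-row table (the values are invented, not taken from data.csv). It mirrors the cleaning code: `Series.str.get_dummies` creates one 0/1 column per distinct value, `add_prefix` names them, and `Type` is cast to `str` first because the `.str` accessor rejects non-string values.

```python
import pandas as pd

# Hypothetical toy data, only to illustrate the encoding
toy = pd.DataFrame({
    "Color": ["Red", "Blue", "Red"],
    "Type": [1, 2, 1],  # numeric, so it needs .apply(str) before .str
    "Odometer": [50000, 80000, 30000],
})

# One indicator column per distinct color, e.g. "Color: Red"
colors = toy["Color"].str.get_dummies().add_prefix("Color: ")
# Cast to str so that the .str accessor can be used on the numeric Type column
types = toy["Type"].apply(str).str.get_dummies().add_prefix("Type: ")

# Append the indicators and drop the original text/category columns
encoded = pd.concat([toy, colors, types], axis=1).drop(["Type", "Color"], axis=1)
print(encoded)
```

Each row has exactly one 1 among the `Color: ...` columns and one among the `Type: ...` columns, so the model sees categories as numbers without imposing an order on them.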

# Correlation matrix of all (now numeric) columns, drawn as a heatmap
matrix = df.corr()
f, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(matrix, square=True)
plt.title('Car Price Variables')
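`df.corr()` uses Pearson correlation by default. As a sketch on hypothetical numbers (not the real data), the entry for a pair of columns can be reproduced from the definition: the dot product of the mean-centered columns divided by the product of their norms.

```python
import numpy as np
import pandas as pd

# Hypothetical values, only to illustrate what df.corr() computes
toy = pd.DataFrame({
    "Odometer": [30000, 50000, 80000, 120000],
    "Ask Price": [1400, 1200, 900, 600],
})

# Pearson correlation (the default method) for every pair of columns
matrix = toy.corr()

# The same coefficient from its definition: centered dot product over the norms
x = toy["Odometer"].to_numpy(dtype=float)
y = toy["Ask Price"].to_numpy(dtype=float)
xc, yc = x - x.mean(), y - y.mean()
r = (xc @ yc) / np.sqrt((xc @ xc) * (yc @ yc))

print(matrix)
print(r)
```

In this toy table a higher odometer reading goes with a lower price, so the coefficient is negative; in the heatmap such cells show up at the opposite end of the color scale from the diagonal, which is always 1.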

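# Pairwise scatter plots of the features and the target
# (seaborn renamed pairplot's `size` parameter to `height`)
sns.pairplot(df[['Construction Year', 'Days Until MOT', 'Odometer', 'Ask Price']], height=3)
plt.show()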

numpy.ravel is equivalent to reshape(-1, order=order): it flattens an array into one dimension.
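A quick check of that equivalence, including a non-default `order`. The column vector mimics the shape of `y_train` after `reshape(-1, 1)`; the arrays themselves are made up for illustration.

```python
import numpy as np

# Column vector of shape (3, 1), like y_train after reshape(-1, 1)
y = np.array([[1.0], [2.0], [3.0]])
flat = y.ravel()       # shape (3,)
same = y.reshape(-1)   # ravel() defaults to order="C", same as reshape(-1)
print(flat.shape, np.array_equal(flat, same))

# With order="F" both read the elements column by column
m = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
print(m.ravel(order="F"), m.reshape(-1, order="F"))
```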


from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from sklearn import preprocessing  # preprocessing utilities
from sklearn.preprocessing import StandardScaler  # standardization
import numpy as np

X = df[['Construction Year', 'Days Until MOT', 'Odometer']]
y = df['Ask Price'].values.reshape(-1, 1)  # Series -> 2D ndarray of shape (n, 1)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=41)

X_normalizer = StandardScaler()  # rescale each feature to mean 0, variance 1
# Fit the scaler on X_train only, then apply the same transform to X_test
X_train = X_normalizer.fit_transform(X_train)
X_test = X_normalizer.transform(X_test)

y_normalizer = StandardScaler()
y_train = y_normalizer.fit_transform(y_train)
y_test = y_normalizer.transform(y_test)

knn = KNeighborsRegressor(n_neighbors=2)
# Alternative: knn.fit(X_train, y_train.ravel()) returns 1D predictions,
# which must be reshaped to (-1, 1) before inverse_transform
knn.fit(X_train, y_train)  # 2D y, so predict() also returns a 2D array

# Now we can predict prices:
y_pred = knn.predict(X_test)
y_pred_inv = y_normalizer.inverse_transform(y_pred)
y_test_inv = y_normalizer.inverse_transform(y_test)

# Build a plot
plt.scatter(y_pred_inv, y_test_inv)
# Now add the perfect prediction line
diagonal = np.linspace(500, 1500, 100)
plt.plot(diagonal, diagonal, '-r')
plt.xlabel('Predicted ask price')
plt.ylabel('Ask price')
plt.show()

print(y_pred_inv)
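With the default uniform weights, KNN regression predicts the mean target of the k nearest training points. The sketch below uses synthetic random data rather than data.csv and a hypothetical helper `knn_predict` that reimplements that rule with NumPy and Euclidean distance (the default metric). It checks the helper against `KNeighborsRegressor`, and also checks that standardizing y and inverting afterwards, as the script above does, gives the same prices as fitting on raw y.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (NOT the real data.csv): 3 features, prices in 500-1500
rng = np.random.default_rng(0)
X_train = rng.normal(size=(20, 3))
y_train = rng.uniform(500, 1500, size=(20, 1))
X_test = rng.normal(size=(5, 3))


def knn_predict(X_train, y_train, X_test, k=2):
    """Hypothetical helper: uniform-weight KNN regression with Euclidean distance."""
    # Pairwise distances, shape (n_test, n_train)
    d = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=2)
    # Indices of the k closest training rows for each test row, shape (n_test, k)
    nearest = np.argsort(d, axis=1)[:, :k]
    # Average their targets: (n_test, k, 1) -> (n_test, 1)
    return y_train[nearest].mean(axis=1)


# Standardize X with statistics from the training split only
x_scaler = StandardScaler().fit(X_train)
Xs_train = x_scaler.transform(X_train)
Xs_test = x_scaler.transform(X_test)

ours = knn_predict(Xs_train, y_train, Xs_test, k=2)
sk = KNeighborsRegressor(n_neighbors=2).fit(Xs_train, y_train).predict(Xs_test).reshape(-1, 1)

# Same model fitted on standardized y, then mapped back to prices
y_scaler = StandardScaler().fit(y_train)
pred_scaled = KNeighborsRegressor(n_neighbors=2).fit(Xs_train, y_scaler.transform(y_train)).predict(Xs_test)
sk_scaled_y = y_scaler.inverse_transform(pred_scaled.reshape(-1, 1))

print(np.allclose(ours, sk), np.allclose(sk, sk_scaled_y))
```

Both comparisons should print True. The second one holds because averaging commutes with the affine map (y - mean) / std, so for KNN the y scaling is optional; the X scaling is what matters, since it changes which neighbors count as nearest.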
