tensorflow波士顿房价线性回归预测入门实践

本文关于使用keras搭建简易的线性回归模型用于波士顿房价数据集的预测。

1.关于数据集
这次任务的数据集我直接使用了tensorflow的dataset内置的波士顿房价的数据集。这个数据集包含有特征值和标签两部分。决定房价的特征值一共有13个,房价采用的是房价中的中位数。

2.关于线性回归
经典的线性回归模型主要是用来预测一些存在着线性关系的数据集。回归可以理解为,存在着一个点集,要用一条曲线去拟合这个点集的分布过程。如果拟合的这一条曲线是直线,那这个回归称为线性回归,同时线性回归也是回归模型中最简单的一种。

3.关于模型
对于这个模型,我才用了最简单的一种假设,即认为该函数为形如y = w1x1+w2x2+w3x3+…+w13x13+b。拟合模型的过程就是求解w与b的过程。
我们在训练模型的过程中,必须有一个指导。起到这个作用的就是损失函数,损失函数会给模型的参数指定一个优化的方向,让函数尽可能的去逼近真是真实值。
在这个模型中,我采用了MSE和MAE两种损失函数进行训练,在这个训练过程中可以发现mae的性能更好一些。模型只用到一个神经元即可,一个神经元会产生13个权重和一个偏置量,如下图。但是由于这个模型实在过于简单,不能取得相对较好的效果,因此多加了一个神经元,如红色字体,我的思路是把13个特征值做处理,转为一个特征值来处理,把多元线性回归假设为一元线性回归来处理。
tensorflow波士顿房价线性回归预测入门实践_第1张图片
4.训练与测试结果
然后自然而然的就是把数据送入模型进行训练测试。
用了两种损失函数,相对而言,均方差的loss在相同的训练强度下loss并不能很快的收敛,在模型的预测结果上也不如mae取得的效果好。
这是mae的loss曲线与真实值与预测值得散点图,如下:
tensorflow波士顿房价线性回归预测入门实践_第2张图片
tensorflow波士顿房价线性回归预测入门实践_第3张图片
由图可知,预测数据在房价值偏大的部分出现了不小的偏离

代码部分如下:

from tensorflow.keras import layers
from tensorflow.keras import Sequential
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import boston_housing
from tensorflow.keras.layers import BatchNormalization

(x,y),(x_,y_) = boston_housing.load_data()

model = Sequential()
model.add(BatchNormalization(input_shape=(13,)))
model.add(layers.Dense(1,activation=None))
model.add(layers.Dense(1,activation=None))
model.compile('adam','mae')
#model.compile('adam','mse')

his = model.fit(x,y,batch_size=32,validation_data=(x_,y_),epochs=50)#

loss = his.history['loss']
val_loss = his.history['val_loss']

plt.figure(figsize=(20,10))
plt.plot(loss,label='train')
plt.plot(val_loss,label='val')
plt.title("linear")
plt.legend()
plt.savefig('./mae_loss.jpg')
plt.show()

Truth = np.array([l for l in y_]).astype('float32').reshape(1,102)
Result = np.array(model.predict(x_.reshape(102,13))).reshape(1,102)

def Draw(X,Y):
    plt.title('Compare')
    x = np.arange(1,50)
    y = x
    plt.plot(x,y)
    plt.scatter(X,Y,color='red')
    plt.savefig('./mae_acc.jpg')
    plt.show()
    
Draw(Truth,Result)

理解与实践多有不足,望批评指正。

你可能感兴趣的:(练习,tensorflow,机器学习,神经网络)