数学推导
过程参考于:
李航《统计学习方法》多项式函数拟合问题V2 https://blog.csdn.net/xiaolewennofollow/article/details/46757657
《统计学习方法》中关于求拟合多项式系数的问题? - wanger的回答 - 知乎https://www.zhihu.com/question/23483726/answer/73307537
假定给定一个训练数据集:
其中,是输入x的观测值,是相应的输出y的观测值,多项式函数拟合的任务是假设给定数据由次多项式函数生成,选择最有可能产生这些数据的次多项式函数,即在M次多项式函数中选择一个对已知数据以及未知数据都有很好预测能力的函数。
设次多项式为
是个参数。其中为矩阵的第行:
用平方损失作为损失函数,系数1/2是为了方便计算,将模型与训练数据代入,有
对求偏导并令其为0
上式可以化简为:
再化简得:
这样就得到了w的解析解,当然我们也可以用最小二乘法来得到w的解,最小二乘法会更加通用.
python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
size = 10 #取点的多少
n = 9 #多项式的阶数
x = np.linspace(0,1,size)
noise = np.random.normal(loc=0.0, scale=0.3, size=size)#加噪声
y = np.sin(x*2*np.pi)+noise
# 生成X矩阵
X = np.zeros((size,n+1))
for i in range(0,n+1):
X[:,i] = np.power(x,i)
# 计算梯度
w = np.dot(np.linalg.inv(np.dot(X.T,X)),np.dot(X.T,y))
print(w)
Y = np.sum(X*w,1)
plt.plot(x,Y)
plt.plot(np.linspace(0,1,100),np.sin(np.linspace(0,1,100)*2*np.pi),label='sin')
plt.scatter(x,y)
过拟合欠拟合分析
如上图所示,当时,分别出现了过拟合,正常拟合,欠拟合的情况,由此可见阶数n对拟合情况的影响。
接下来我们考察数据规模对模型的影响,在n=9的情况下,加大数据点数size到15与100
可以看出增大数据集的规模可以减小过拟合情况,也就是说数据集越大,模型受噪声的影响越小,越能体现数据真正的规律.
正则化
正则化的原理其实就是通过给误差函数加一个惩罚项,使得系数不会达到很大的值,系数变小后自然过拟合的可能性就降低了。
这个行为也符合奥卡姆剃刀原理:在所有可能选择的模型中,能够很好地解释已知数据并十分简单的才是最好的模型。
引入正则项后,我们就不太方便直接求得w的解析解,因此我们引入最小二乘法来计算w.
from scipy.optimize import leastsq
def fit_func(w,x):
f = np.poly1d(w)
return f(x)
def cost(w,x,y):
res = fit_func(w,x)-y
return res
def cost_with_reg(w,x,y):
res = fit_func(w,x)-y
res = np.append(res,np.sqrt(regularization) * w)
return res
regularization =1.52299797e-8#正则化系数
w = np.random.randn(n)
w = leastsq(cost_with_reg,w,args=(x,y))
plot_x = np.linspace(0,1,100)
print('拟合参数: ', w)
plt.plot(plot_x,fit_func(w[0],plot_x))
plt.plot(np.linspace(0,1,100),np.sin(np.linspace(0,1,100)*2*np.pi),label='sin')
plt.scatter(x,y,None,'r')
plt.text(0.7,0.8,'lnλ=-18',fontsize=20)
对比未加正则项的结果:
可以看出给正则项合适的lambda可以大大增加模型的泛化能力,并且对比参数表可知,加了正则项后的大小有了很好的限制,不会像无正则项时变得非常大.
附录:
参数表:
无正则项,n=9时:[ 1.11448699e+04, -4.42672236e+04, 7.16836458e+04, -6.09053586e+04,
2.90933720e+04, -7.73548740e+03, 1.03777920e+03, -5.22070047e+01,
7.96235723e-01]
有正则项,n=9时:[ 0.24473024, 0.2027851 , 0.14337231, 0.05886199, -0.0607331 ,
-0.22495246, -0.42590437, -0.5518487 , 0.43506335]
有正则项,n=9时:[ 174.12842562, -389.84399225, 133.95566243, 185.06182615,
-37.81153661, -108.20050901, 46.69397522, -4.59371681,
0.7814253 ]