2022年 11月 5日

python之拟合

一、多项式拟合

多项式拟合的话,用的的是numpy这个库的polyfit这个函数。那么多项式拟合,最简单的当然是,一次多项式拟合了,就是线性回归。直接看代码吧

  1. import numpy as np
  2. def linear_regression(x,y):
  3. #y=bx+a,线性回归
  4. num=len(x)
  5. b=(np.sum(x*y)-num*np.mean(x)*np.mean(y))/(np.sum(x*x)-num*np.mean(x)**2)
  6. a=np.mean(y)-b*np.mean(x)
  7. return np.array([b,a])
  8. def f(x):
  9. return 2*x+1
  10. x=np.linspace(-5,5)
  11. y=f(x)+np.random.randn(len(x))#加入噪音
  12. y_fit=np.polyfit(x,y,1)#一次多项式拟合,也就是线性回归
  13. print(linear_regression(x,y))
  14. print(y_fit)

手写线性回归我还是会的,然后我们来看下输出:

  1. [1.9937839 1.24167225]
  2. [1.9937839 1.24167225]

由于有random每次显示的结果都不一样,但很明显的是上下两个print是意料之中的一样,emmmmm,一次多项式拟合的源代码应该就是像我写的那样。好了,那么一次以上呢?咳咳,我数学不算太好,还是老老实实用库函数吧,顺便画下图,见识它的威力。

  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. def f(x):
  4. return x**2+1
  5. def f_fit(x,y_fit):
  6. a,b,c=y_fit.tolist()
  7. return a*x**2+b*x+c
  8. x=np.linspace(-5,5)
  9. y=f(x)+np.random.randn(len(x))#加入噪音
  10. y_fit=np.polyfit(x,y,2)#二次多项式拟合
  11. y_show=np.poly1d(y_fit)#函数优美的形式
  12. print(y_show)#打印
  13. y1=f_fit(x,y_fit)
  14. plt.plot(x,f(x),'r',label='original')
  15. plt.scatter(x,y,c='g',label='before_fitting')#散点图
  16. plt.plot(x,y1,'b--',label='fitting')
  17. plt.title('polyfitting')
  18. plt.xlabel('x')
  19. plt.ylabel('y')
  20. plt.legend()#显示标签
  21. plt.show()

输出:

  1. 2
  2. 1.001 x - 0.04002 x + 0.8952

拟合效果看起来还是不错的。

二、各种函数的拟合

一般来说,多项式的拟合就能拟合很多函数了,比如指数函数,取对数就能化为多项式函数,甚至是一次多项式函数。可是,那些三角函数之类的复杂函数不能化为多项式去拟合,怎么办呢?要用到scipy.optimize的curve_fit函数了。

直接贴代码:

  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. from scipy.optimize import curve_fit
  4. def f(x):
  5. return 2*np.sin(x)+3
  6. def f_fit(x,a,b):
  7. return a*np.sin(x)+b
  8. def f_show(x,p_fit):
  9. a,b=p_fit.tolist()
  10. return a*np.sin(x)+b
  11. x=np.linspace(-2*np.pi,2*np.pi)
  12. y=f(x)+0.5*np.random.randn(len(x))#加入了噪音
  13. p_fit,pcov=curve_fit(f_fit,x,y)#曲线拟合
  14. print(p_fit)#最优参数
  15. print(pcov)#最优参数的协方差估计矩阵
  16. y1=f_show(x,p_fit)
  17. plt.plot(x,f(x),'r',label='original')
  18. plt.scatter(x,y,c='g',label='before_fitting')#散点图
  19. plt.plot(x,y1,'b--',label='fitting')
  20. plt.xlabel('x')
  21. plt.ylabel('y')
  22. plt.legend()
  23. plt.show()

输出:

  1. [1.91267059 3.04489528]
  2. [[ 9.06910892e-03 -1.83703696e-11]
  3. [-1.83703696e-11 4.44386331e-03]]

使用方法基础的就是这样了。然后更多详细的参数的使用就是要看官网了。

1、https://docs.scipy.org/doc/numpy/reference/generated/numpy.polyfit.html

2、https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.optimize.curve_fit.html