本文原作者:尹迪,经授权后发布。

1 原理

迭代再加权最小二乘(IRLS)用于解决特定的最优化问题,这个最优化问题的目标函数如下所示:

$$arg min_{\beta} \sum_{i=1}^{n}|y_{i} - f_{i}(\beta)|^{p}$$

这个目标函数可以通过迭代的方法求解。在每次迭代中,解决一个带权最小二乘问题,形式如下:

$$\beta ^{t+1} = argmin_{\beta} \sum_{i=1}^{n} w_{i}(\beta^{(t)}))|y_{i} - f_{i}(\beta)|^{2} = (X{T}WX){-1}XW^{(t)}y$$

在这个公式中,$W^{(t)}$是权重对角矩阵,它的所有元素都初始化为1。每次迭代中,通过下面的公式更新。

$$W_{i}^{(t)} = |y_{i} - X_{i}\beta{(t)}|$$

2 源码分析

spark ml中,迭代再加权最小二乘主要解决广义线性回归问题。下面看看实现代码。

2.1 更新权重

 // Update offsets and weights using reweightFunc
 val newInstances = instances.map { instance =>
    val (newOffset, newWeight) = reweightFunc(instance, oldModel)
    Instance(newOffset, newWeight, instance.features)
 }

这里使用reweightFunc方法更新权重。具体的实现在广义线性回归的实现中。

    /**
     * The reweight function used to update offsets and weights
     * at each iteration of [[IterativelyReweightedLeastSquares]].
     */
    val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
      (instance: Instance, model: WeightedLeastSquaresModel) => {
        val eta = model.predict(instance.features)
        val mu = fitted(eta)
        val offset = eta + (instance.label - mu) * link.deriv(mu)
        val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
        (offset, weight)
      }
    }
    
    def fitted(eta: Double): Double = family.project(link.unlink(eta))

这里的model.predict利用带权最小二乘模型预测样本的取值,然后调用fitted方法计算均值函数$\mu$。offset表示 更新后的标签值,weight表示更新后的权重。关于链接函数的相关计算可以参考广义线性回归的分析。

有一点需要说明的是,这段代码中标签和权重的更新并没有参照上面的原理或者说我理解有误。

2.2 训练新的模型

  // 使用更新过的样本训练新的模型 
  model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
        standardizeFeatures = false, standardizeLabel = false).fit(newInstances)

  // 检查是否收敛
  val oldCoefficients = oldModel.coefficients
  val coefficients = model.coefficients
  BLAS.axpy(-1.0, coefficients, oldCoefficients)
  val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
        math.max(math.abs(x), math.abs(y))
  }
  val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))
  if (maxTol < tol) {
    converged = true
  }

训练完新的模型后,重复2.1步,直到参数收敛或者到达迭代的最大次数。

3 参考文献

【1】Iteratively reweighted least squares

文章来源于腾讯云开发者社区,点击查看原文