找回密码
 注nanjixiong2017册

QQ登录

只需一步,快速开始

发表帖子

(Caffe,LeNet)权值更新(七)

[复制链接]
在Solver::ApplyUpdate()函数中,根据反向传播阶段计算的loss关于网络权值的偏导,使用配置的学习策略,更新网络权值从而完成本轮学习。
1 模型优化1.1 损失函数
损失函数L(W)可由经验损失加正则化项得到,如下,其中X(i)为输入样本;fW为某样本的损失函数;N为mini-batch的样本数量;r(W)为以权值为λ的正则项。
L(W)≈1N∑NifW(X(i))+λr(W)
在caffe中,可以分为三个阶段:
  • 前向计算阶段,这个阶段计算fW
  • 反向传播阶段,这个阶段计算∇fW
  • 权值更新阶段,这个阶段通过∇fW,∇r(W)等计算ΔW从而更新W
1.2 随机梯度下降
在lenet中,solver的类型为SGD(Stochastic gradient descent)
SGD通过以下公式对权值进行更新:
Wt+1=Wt+Vt+1
Vt+1=μVt−α∇L(Wt)
其中,Wt+1为第t+1轮的权值;Vt+1为第t+1轮的更新(也可以写作ΔWt+1);μ为上一轮更新的权重;α为学习率;∇L(Wt)为loss对权值的求导
2 代码分析2.1 ApplyUpdate
  1. void SGDSolver<Dtype>::ApplyUpdate() {
  2.   // 获取该轮迭代的学习率(learning rate)
  3.   Dtype rate = GetLearningRate();

  4.   // 对每一层网络的权值进行更新
  5.   // 在lenet中,只有`conv1`,`conv2`,`ip1`,`ip2`四层有参数
  6.   // 每层分别有参数与偏置参数两项参数
  7.   // 因而`learnable_params_`的size为8.
  8.   for (int param_id = 0; param_id < this->net_->learnable_params().size();
  9.        ++param_id) {
  10.     // 归一化,iter_size为1不需要,因而lenet不需要。
  11.     // 此处的归一化内容很简单,仅仅是iter_size大于1时值再除以iter_size
  12.     Normalize(param_id);
  13.     // 正则化
  14.     Regularize(param_id);
  15.     // 计算更新值\delta w
  16.     ComputeUpdateValue(param_id, rate);
  17.   }
  18.   // 更新权值
  19.   this->net_->Update();
  20. }
复制代码

说明:

lenet中学习参数设置可从lenet_solver.prototxt中查到
  1. # The base learning rate, momentum and the weight decay of the network.

  2. base_lr: 0.01
  3. momentum: 0.9
  4. weight_decay: 0.0005

  5. # The learning rate policy

  6. lr_policy: "inv"
  7. gamma: 0.0001
  8. power: 0.75
复制代码
获取学习率函数ApplyUpdate代码此处不给出,查看注释(以及caffe.proto)可知有如下学习率获取策略。在Lenet中采用的是inv的策略,是一种没一轮迭代学习率都改变的策略。
  1. // The learning rate decay policy. The currently implemented learning rate
  2.   // policies are as follows:
  3.   //    - fixed: always return base_lr.
  4.   //    - step: return base_lr * gamma ^ (floor(iter / step))
  5.   //    - exp: return base_lr * gamma ^ iter
  6.   //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
  7.   //    - multistep: similar to step but it allows non uniform steps defined by
  8.   //      stepvalue
  9.   //    - poly: the effective learning rate follows a polynomial decay, to be
  10.   //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
  11.   //    - sigmoid: the effective learning rate follows a sigmod decay
  12.   //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
  13.   //
  14.   // where base_lr, max_iter, gamma, step, stepvalue and power are defined
  15.   // in the solver parameter protocol buffer, and iter is the current iteration.
复制代码
2.2 Regularize
该函数实际执行以下公式
∂loss∂wij=decay∗wij+∂loss∂wij
代码如下:
  1. void SGDSolver<Dtype>::Regularize(int param_id) {
  2.   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  3.   const vector<float>& net_params_weight_decay =
  4.       this->net_->params_weight_decay();
  5.   Dtype weight_decay = this->param_.weight_decay();
  6.   string regularization_type = this->param_.regularization_type();
  7.   // local_decay = 0.0005 in lenet
  8.   Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

  9.   ...
  10.       if (regularization_type == "L2") {
  11.         // axpy means ax_plus_y. i.e., y = a*x + y
  12.         caffe_axpy(net_params[param_id]->count(),
  13.             local_decay,
  14.             net_params[param_id]->cpu_data(),
  15.             net_params[param_id]->mutable_cpu_diff());
  16.       }
  17.   ...
  18. }
复制代码
2.3 ComputeUpdateValue
该函数实际执行以下公式
vij=lr_rate∗∂loss∂wij+momentum∗vij
∂loss∂wij=vij

代码如下:
  1. void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  2.   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  3.   const vector<float>& net_params_lr = this->net_->params_lr();
  4.   // momentum = 0.9 in lenet
  5.   Dtype momentum = this->param_.momentum();
  6.   // local_rate = lr_mult * global_rate
  7.   // lr_mult为该层学习率乘子,在lenet_train_test.prototxt中设置
  8.   Dtype local_rate = rate * net_params_lr[param_id];

  9.   // Compute the update to history, then copy it to the parameter diff.

  10.   ...
  11.     // axpby means ax_plus_by. i.e., y = ax + by
  12.     // 计算新的权值更新变化值 \delta w,结果保存在历史权值变化中
  13.     caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
  14.               net_params[param_id]->cpu_diff(), momentum,
  15.               history_[param_id]->mutable_cpu_data());

  16.     // 从历史权值变化中把变化值 \delta w 保存到历史权值中diff中
  17.     caffe_copy(net_params[param_id]->count(),
  18.         history_[param_id]->cpu_data(),
  19.         net_params[param_id]->mutable_cpu_diff());
  20.    ...
  21. }
复制代码
2.4 net_->Update
以下公式:

实际执行
wij=wij+(−1)∗∂loss∂wij
  1. </blockquote></div><div class="blockcode"><blockquote>caffe_axpy<Dtype>(count_, Dtype(-1),
  2.         static_cast<const Dtype*>(diff_->cpu_data()),
  3.         static_cast<Dtype*>(data_->mutable_cpu_data()))
复制代码

出处:http://blog.csdn.net/mounty_fsc/



使用道具 举报 回复
您需要登录后才可以回帖 登录 | 注nanjixiong2017册

本版积分规则