caffe: Source Code Reading Notes (1)

This series documents a reading of the caffe source code. The goal is to understand the functionality caffe provides and how it is implemented, and to summarize the parts of the code that are important or hard to understand, both as a reminder to myself and as a guide for friends reading the code.

1. Overall functionality

As a tool, caffe provides four top-level actions:
Train: train / finetune a model. This post walks through single-GPU training; the next post will cover the multi-GPU training process.
Time: benchmark the execution time of a model. This records the forward and backward time of every layer of the network; the model comes from the command-line arguments. The mechanism is simple (a Timer class with Start and Stop calls), so it is not detailed in this post.
Test: score a model. This runs forward inference with a model and records the score of every output of every iteration; the model comes from the command-line arguments.
Device Query: show diagnostic information for a GPU device. This queries the current GPU via the CUDA API; the implementation is straightforward and not covered here.

1.1 Function registration

RegisterBrewFunction(device_query);
RegisterBrewFunction(train);
RegisterBrewFunction(test);
RegisterBrewFunction(time);

The registration is implemented by a macro:

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

After registration, g_brew_map contains four function pointers:

g_brew_map["device_query"]  --> int device_query()
g_brew_map["train"]  --> int train()
g_brew_map["test"]  --> int test()
g_brew_map["time"]  --> int time()
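
To make the pattern concrete, here is a minimal self-contained sketch of the same static-registration idiom (simplified, with hypothetical names, independent of caffe):

#include <iostream>
#include <map>
#include <string>

typedef int (*BrewFunction)();
typedef std::map<std::string, BrewFunction> BrewMap;
static BrewMap g_brew_map;

// Each use of the macro defines a tiny class whose constructor runs during
// static initialization, before main(), and inserts the function into the map.
#define RegisterBrewFunction(func)                          \
  namespace {                                               \
  struct __Registerer_##func {                              \
    __Registerer_##func() { g_brew_map[#func] = &func; }    \
  };                                                        \
  __Registerer_##func g_registerer_##func;                  \
  }

int train() { std::cout << "training...\n"; return 0; }
RegisterBrewFunction(train);

int main() {
  // Dispatch by name, exactly as caffe's main() does with argv[1].
  return g_brew_map["train"]();
}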

1.2 Function dispatch

To invoke train, the tool calls GetBrewFunction("train")().
The implementation is:

static BrewFunction GetBrewFunction(const caffe::string& name) {
  if (g_brew_map.count(name)) {
    return g_brew_map[name];
  } else {
    LOG(ERROR) << "Available caffe actions:";
    for (BrewMap::iterator it = g_brew_map.begin();
         it != g_brew_map.end(); ++it) {
      LOG(ERROR) << "\t" << it->first;
    }
    LOG(FATAL) << "Unknown action: " << name;
    return NULL;  // not reachable, just to suppress old compiler warnings.
  }
}
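
In tools/caffe.cpp, main() uses this dispatcher roughly as follows (an abridged sketch from memory; only the dispatch line is the point here):

int main(int argc, char** argv) {
  caffe::GlobalInit(&argc, &argv);
  if (argc == 2) {
    // argv[1] is one of: train, test, time, device_query.
    return GetBrewFunction(caffe::string(argv[1]))();
  }
  ...  // otherwise print usage information
}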

2. train: usage and implementation

2.1 Usage

Once compiled, caffe runs as an executable. To train a model, the command line looks like this (note that the action, train, comes first):

caffe train --solver=**.prototxt --weights=**.caffemodel --gpu=0,1,...
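
For example, with the MNIST LeNet example shipped in the caffe repo (paths as in the upstream tree), the first command trains from scratch and the second resumes from a snapshot:

caffe train --solver=examples/mnist/lenet_solver.prototxt
caffe train --solver=examples/mnist/lenet_solver.prototxt --snapshot=examples/mnist/lenet_iter_5000.solverstate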

2.2 Implementation

2.2.1 The FLAGS_* variables

This refers to variables such as FLAGS_solver, FLAGS_snapshot, and FLAGS_weights.

2.2.1.1 Definition

Taking FLAGS_solver as the example:

DEFINE_string(solver, "",
    "The solver definition protocol buffer text file.");

The underlying gflags implementation is quite interesting:

#define DEFINE_string(name, val, txt)                                       \
  namespace fLS {                                                           \
    using ::fLS::clstring;                                                  \
    static union { void* align; char s[sizeof(clstring)]; } s_##name[2];    \
    clstring* const FLAGS_no##name = ::fLS::                                \
                                   dont_pass0toDEFINE_string(s_##name[0].s, \
                                                             val);          \
    static GFLAGS_NAMESPACE::FlagRegisterer o_##name(                       \
        #name, "string", MAYBE_STRIPPED_HELP(txt), __FILE__,                \
        s_##name[0].s, new (s_##name[1].s) clstring(*FLAGS_no##name));      \
    extern GFLAGS_DLL_DEFINE_FLAG clstring& FLAGS_##name;                   \
    using fLS::FLAGS_##name;                                                \
    clstring& FLAGS_##name = *FLAGS_no##name;                               \
  }                                                                         \
  using fLS::FLAGS_##name
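
From the user's point of view, the macro simply materializes a global FLAGS_solver string. A minimal standalone gflags example (a hypothetical demo file, not from caffe) behaves like this:

#include <gflags/gflags.h>
#include <iostream>

DEFINE_string(solver, "", "The solver definition protocol buffer text file.");

int main(int argc, char** argv) {
  // Strips recognized flags from argv, just as caffe::GlobalInit does.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "FLAGS_solver = " << FLAGS_solver << std::endl;
  return 0;
}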

2.2.1.2 Initialization

These variables are initialized from the command-line arguments via caffe::GlobalInit(&argc, &argv), implemented as follows:

void GlobalInit(int* pargc, char*** pargv) {
  // Google flags.
  ::gflags::ParseCommandLineFlags(pargc, pargv, true);
  // Google logging.
  ::google::InitGoogleLogging(*(pargv)[0]);
  // Provide a backtrace on segfault.
  ::google::InstallFailureSignalHandler();
}

For the implementation of ParseCommandLineFlags itself, see the gflags library; for typical users, knowing what it does is enough.

2.2.2 SolverParameter

After command-line parsing, the first major class we meet is the Solver; SolverParameter carries its parameters.

2.2.2.1 Definition

solver_param is defined in caffe.proto, using the protobuf framework to define the message and generate the code automatically. The definition follows; the individual fields can be understood from their names and comments.

message SolverParameter {
  //////////////////////////////////////////////////////////////////////////////
  // Proto filename for the train net, possibly combined with one or more
  // test nets.
  optional string net = 24;
  // Inline train net param, possibly combined with one or more test nets.
  optional NetParameter net_param = 25;

  optional string train_net = 1; // Proto filename for the train net.
  repeated string test_net = 2; // Proto filenames for the test nets.
  optional NetParameter train_net_param = 21; // Inline train net params.
  repeated NetParameter test_net_param = 22; // Inline test net params.

  optional NetState train_state = 26;
  repeated NetState test_state = 27;

  // The number of iterations for each test net.
  repeated int32 test_iter = 3;

  // The number of iterations between two testing phases.
  optional int32 test_interval = 4 [default = 0];
  optional bool test_compute_loss = 19 [default = false];
  // If true, run an initial test pass before the first iteration,
  // ensuring memory availability and printing the starting value of the loss.
  optional bool test_initialization = 32 [default = true];
  optional float base_lr = 5; // The base learning rate
  // the number of iterations between displaying info. If display = 0, no info
  // will be displayed.
  optional int32 display = 6;
  // Display the loss averaged over the last average_loss iterations
  optional int32 average_loss = 33 [default = 1];
  optional int32 max_iter = 7; // the maximum number of iterations
  // accumulate gradients over `iter_size` x `batch_size` instances
  optional int32 iter_size = 36 [default = 1];

  optional string lr_policy = 8;
  optional float gamma = 9; // The parameter to compute the learning rate.
  optional float power = 10; // The parameter to compute the learning rate.
  optional float momentum = 11; // The momentum value.
  optional float weight_decay = 12; // The weight decay.
  // regularization types supported: L1 and L2
  // controlled by weight_decay
  optional string regularization_type = 29 [default = "L2"];
  // the stepsize for learning rate policy "step"
  optional int32 stepsize = 13;
  // the stepsize for learning rate policy "multistep"
  repeated int32 stepvalue = 34;

  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval

  optional string snapshot_prefix = 15;
 
  optional bool snapshot_diff = 16 [default = false];
  enum SnapshotFormat {
    HDF5 = 0;
    BINARYPROTO = 1;
  }
  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
  // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
  enum SolverMode {
    CPU = 0;
    GPU = 1;
  }
  optional SolverMode solver_mode = 17 [default = GPU];
  // the device_id will that be used in GPU mode. Use device_id = 0 in default.
  optional int32 device_id = 18 [default = 0];
  // If non-negative, the seed with which the Solver will initialize the Caffe
  // random number generator -- useful for reproducible results. Otherwise,
  // (and by default) initialize using a seed derived from the system clock.
  optional int64 random_seed = 20 [default = -1];
  // type of the solver
  optional string type = 40 [default = "SGD"];
  // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
  optional float delta = 31 [default = 1e-8];
  // parameters for the Adam solver
  optional float momentum2 = 39 [default = 0.999];
  optional float rms_decay = 38 [default = 0.99];
  optional bool debug_info = 23 [default = false];
  // If false, don't save a snapshot after training finishes.
  optional bool snapshot_after_train = 28 [default = true];
  // DEPRECATED: old solver enum types, use string instead
  enum SolverType {
    SGD = 0;
    NESTEROV = 1;
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  // DEPRECATED: use type instead of solver_type
  optional SolverType solver_type = 30 [default = SGD];
  // Overlap compute and communication for data parallel training
  optional bool layer_wise_reduce = 41 [default = true];
  repeated string weights = 42;
}
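
As a concrete illustration, here is a typical solver definition in protobuf text format (essentially the lenet_solver.prototxt shipped with caffe's MNIST example; treat the values as illustrative):

net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_mode: GPU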
2.2.2.2 Initialization

train() in tools/caffe.cpp builds up solver_param and the solver in the following steps:

1. Parse the solver file: caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
                                       SolverParameter* param) {
  CHECK(ReadProtoFromTextFile(param_file, param))
      << "Failed to parse SolverParameter file: " << param_file;
  UpgradeSolverAsNeeded(param_file, param);
  UpgradeSnapshotPrefixProperty(param_file, param);
}

2. Level: solver_param.mutable_train_state()->set_level(FLAGS_level);

3. Stages:

vector<string> stages = get_stages_from_flags();
solver_param.mutable_train_state()->add_stage(stages[i]);

4. GPU/CPU mode:

solver_param.set_device_id(0);
Caffe::set_mode(Caffe::GPU);

5. Signal handling: caffe::SignalHandler signal_handler()

6. Pretrained weights: solver_param.add_weights(FLAGS_weights);

7. Solver creation (described in detail later):

shared_ptr<caffe::Solver<float> >
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

8. Resume from a snapshot, if one was given: solver->Restore(FLAGS_snapshot.c_str());

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}

9. Run (described in detail later):

nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);

or, in the single-GPU / CPU case:

solver->Solve();

2.2.3 Solver

Solver is one of caffe's central classes. Its concrete subclasses cover SGD, Nesterov, AdaGrad, RMSProp, AdaDelta, and Adam; SGD serves as the example below.

2.2.3.1 SGDSolver定义

SGDSolver is defined in sgd_solvers.hpp and sgd_solver.cpp. Its parent class Solver defines the basic solver functionality (abridged below):

template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param);
  explicit Solver(const string& param_file);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

  void SetActionFunction(ActionCallback func);
  SolverAction::Enum GetRequestedAction();

  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
 
  void Restore(const char* resume_file);

  void Snapshot();
  virtual ~Solver() {}

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }
  virtual void ApplyUpdate() = 0;
 protected:
  string SnapshotFilename(const string& extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
  DISABLE_COPY_AND_ASSIGN(Solver);
};

2.2.3.2 SGDSolver registration

Without further ado, the registration code is:

INSTANTIATE_CLASS(SGDSolver);
REGISTER_SOLVER_CLASS(SGD);
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
      Solver<Dtype>* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};

At this point, registry["SGD"] maps to Creator_SGDSolver, which constructs an SGDSolver.

Once registered, the creator is fetched when the solver is built (step 7 of the initialization list above). Registration itself lands in AddCreator:

// Adds a creator.
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }
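
The fetch side is SolverRegistry<Dtype>::CreateSolver, which looks the requested type up in the same registry (slightly abridged from solver_factory.hpp):

  // Get a solver using a SolverParameter.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }
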
2.2.3.3 Using SGDSolver

This corresponds to the run step (step 9) of the initialization list above. The main flow of void Solver<Dtype>::Solve(const char* resume_file) is:

1. Restore(resume_file);
2. Step(param_.max_iter() - iter_);
   This is the core function: it loops for the number of iterations set in the configuration file, and each iteration runs forward inference and backward propagation, then updates the loss; the parameter update itself is delegated to ApplyUpdate(), sketched after the code below. Forward and backward passes on the Net are covered in the next section (Net).
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
 
  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())) {    
        TestAll();
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
  
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
  
    UpdateSmoothedLoss(loss, start_iter, average_loss);

    ApplyUpdate();

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot() && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
        (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
  }
}
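
ApplyUpdate() is declared pure virtual in Solver; each subclass supplies its update rule. For SGDSolver it looks roughly like this (abridged from sgd_solver.cpp; the display logging is omitted):

template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  Dtype rate = GetLearningRate();  // applies the configured lr_policy
  ClipGradients();                 // no-op unless clip_gradients >= 0
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);                 // divide diffs by iter_size
    Regularize(param_id);                // add weight decay to the diffs
    ComputeUpdateValue(param_id, rate);  // momentum + learning rate
  }
  this->net_->Update();  // data_ -= diff_ for every learnable blob
}
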
3. Snapshot();
   This saves the training state periodically over the course of training; the save interval is controlled by the input parameters, and files are written using the snapshot prefix.
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }
  SnapshotSolverState(model_filename);
}
4. At the configured iteration counts, run forward inference and update the loss via the two calls below (covered in a later code walkthrough):
   net_->Forward(&loss);
   UpdateSmoothedLoss(loss, start_iter, average_loss);
5. At the configured test interval, run the tests; during training this mainly lets the engineer judge overfitting, underfitting, and convergence. The work is running the test Net's forward pass, then updating and recording the loss; the entry point is TestAll(). Abridged code:
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
 
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
  }  
  for (int i = 0; i < test_score.size(); ++i) {
   ......
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}
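
The smoothed-loss bookkeeping used by Step() keeps a sliding window of the last average_loss values; a sketch of Solver<Dtype>::UpdateSmoothedLoss from solver.cpp (from memory, lightly abridged):

template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() < average_loss) {
    // Window not yet full: append and take the running mean.
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    // Window full: replace the oldest entry and adjust the mean.
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
    losses_[idx] = loss;
  }
}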

2.2.4 Net

Net is caffe's other core class; as the code above shows, the Solver's work consists of calls into Net operations.

2.2.4.1 Net definition

template <typename Dtype>
class Net {
 public:
  explicit Net(const NetParameter& param);
  explicit Net(const string& param_file, Phase phase,
      const int level = 0, const vector<string>* stages = NULL);
  virtual ~Net() {}
  /// @brief Initialize a network with a NetParameter.
  void Init(const NetParameter& param);
 
  const vector<Blob<Dtype>*>& Forward(Dtype* loss = NULL);

  Dtype ForwardFromTo(int start, int end);
  Dtype ForwardFrom(int start);
  Dtype ForwardTo(int end);
  const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom,
      Dtype* loss = NULL);

  void Backward();
  void BackwardFromTo(int start, int end);
  void BackwardFrom(int start);
  void BackwardTo(int end);
  void Reshape();
  Dtype ForwardBackward();
  void ShareWeights();
  void ShareTrainedLayersWith(const Net* other);
  void CopyTrainedLayersFrom(const NetParameter& param);
  void CopyTrainedLayersFrom(const string& trained_filename);
  void CopyTrainedLayersFromBinaryProto(const string& trained_filename);
  void CopyTrainedLayersFromHDF5(const string& trained_filename);
  /// @brief Writes the net to a proto.
  void ToProto(NetParameter* param, bool write_diff = false) const;
  /// @brief Writes the net to an HDF5 file.
  void ToHDF5(const string& filename, bool write_diff = false) const;

  static void FilterNet(const NetParameter& param, NetParameter* param_filtered);

  static bool StateMeetsRule(const NetState& state, const NetStateRule& rule,
      const string& layer_name);
  class Callback {
   protected:
    virtual void run(int layer) = 0;

    template <typename T>
    friend class Net;
  };
 protected:
  /// @brief Append a new top blob to the net.
  void AppendTop(const NetParameter& param, const int layer_id,
                 const int top_id, set<string>* available_blobs,
                 map<string, int>* blob_name_to_idx);
  /// @brief Append a new bottom blob to the net.
  int AppendBottom(const NetParameter& param, const int layer_id,
                   const int bottom_id, set<string>* available_blobs,
                   map<string, int>* blob_name_to_idx);
  /// @brief Append a new parameter blob to the net.
  void AppendParam(const NetParameter& param, const int layer_id,
                   const int param_id);

  DISABLE_COPY_AND_ASSIGN(Net);
};

2.2.4.2 Main Net operations

2.2.4.2.1 ForwardBackward
  Dtype ForwardBackward() {
    Dtype loss;
    Forward(&loss);
    Backward();
    return loss;
  }
2.2.4.2.2 Forward
template <typename Dtype>
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
  if (loss != NULL) {
    *loss = ForwardFromTo(0, layers_.size() - 1);
  } else {
    ForwardFromTo(0, layers_.size() - 1);
  }
  return net_output_blobs_;
}
template <typename Dtype>
Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
  CHECK_GE(start, 0);
  CHECK_LT(end, layers_.size());
  Dtype loss = 0;
  for (int i = start; i <= end; ++i) {
    for (int c = 0; c < before_forward_.size(); ++c) {
      before_forward_[c]->run(i);
    }
    Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
    loss += layer_loss;
    if (debug_info_) { ForwardDebugInfo(i); }
    for (int c = 0; c < after_forward_.size(); ++c) {
      after_forward_[c]->run(i);
    }
  }
  return loss;
}
2.2.4.2.3 Backward
template <typename Dtype>
void Net<Dtype>::Backward() {
  BackwardFromTo(layers_.size() - 1, 0);
}
template <typename Dtype>
void Net<Dtype>::BackwardFromTo(int start, int end) {
  CHECK_GE(end, 0);
  CHECK_LT(start, layers_.size());
  for (int i = start; i >= end; --i) {
    for (int c = 0; c < before_backward_.size(); ++c) {
      before_backward_[c]->run(i);
    }
    if (layer_need_backward_[i]) {
      layers_[i]->Backward(
          top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]);
      if (debug_info_) { BackwardDebugInfo(i); }
    }
    for (int c = 0; c < after_backward_.size(); ++c) {
      after_backward_[c]->run(i);
    }
  }
}

2.2.5 Layer

2.2.5.1 Definition

As the code above shows, Net operations are carried out by the layers, which is natural: a net is composed of layers, so the net's numerical work loops over the layers and invokes their operations. The (abridged) definition follows; the base Layer class serves as the example here.

template <typename Dtype>
class Layer {
 public:
  explicit Layer(const LayerParameter& param)
    : layer_param_(param) {
      // Set phase and copy blobs (if there are any).
      phase_ = param.phase();
      if (layer_param_.blobs_size() > 0) {
        blobs_.resize(layer_param_.blobs_size());
        for (int i = 0; i < layer_param_.blobs_size(); ++i) {
          blobs_[i].reset(new Blob<Dtype>());
          blobs_[i]->FromProto(layer_param_.blobs(i));
        }
      }
    }
  virtual ~Layer() {}
  void SetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    CheckBlobCounts(bottom, top);
    LayerSetUp(bottom, top);
    Reshape(bottom, top);
    SetLossWeights(top);
  }
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) = 0;
  inline Dtype Forward(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  inline void Backward(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom);
  virtual void ToProto(LayerParameter* param, bool write_diff = false);
 protected:
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) = 0;
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom);
  DISABLE_COPY_AND_ASSIGN(Layer);
};  // class Layer
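
To see how a concrete layer plugs into this interface, here is a minimal hypothetical layer (a made-up "scale by two" layer, not part of caffe) overriding the usual virtuals; a sketch assuming caffe's headers and its using-declarations for vector:

template <typename Dtype>
class ScaleByTwoLayer : public Layer<Dtype> {
 public:
  explicit ScaleByTwoLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    top[0]->ReshapeLike(*bottom[0]);  // output has the input's shape
  }
 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    const Dtype* in = bottom[0]->cpu_data();
    Dtype* out = top[0]->mutable_cpu_data();
    for (int i = 0; i < bottom[0]->count(); ++i) {
      out[i] = Dtype(2) * in[i];  // y = 2x
    }
  }
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) {
    if (!propagate_down[0]) { return; }
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
    for (int i = 0; i < top[0]->count(); ++i) {
      bottom_diff[i] = Dtype(2) * top_diff[i];  // dL/dx = 2 * dL/dy
    }
  }
};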

2.2.5.2 Main Layer operations

2.2.5.2.1 Forward
template <typename Dtype>
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  Dtype loss = 0;
  Reshape(bottom, top);
  switch (Caffe::mode()) {
  case Caffe::CPU:
    Forward_cpu(bottom, top);
    for (int top_id = 0; top_id < top.size(); ++top_id) {
      if (!this->loss(top_id)) { continue; }
      const int count = top[top_id]->count();
      const Dtype* data = top[top_id]->cpu_data();
      const Dtype* loss_weights = top[top_id]->cpu_diff();
      loss += caffe_cpu_dot(count, data, loss_weights);
    }
    break;
  case Caffe::GPU:
    Forward_gpu(bottom, top);
#ifndef CPU_ONLY
    for (int top_id = 0; top_id < top.size(); ++top_id) {
      if (!this->loss(top_id)) { continue; }
      const int count = top[top_id]->count();
      const Dtype* data = top[top_id]->gpu_data();
      const Dtype* loss_weights = top[top_id]->gpu_diff();
      Dtype blob_loss = 0;
      caffe_gpu_dot(count, data, loss_weights, &blob_loss);
      loss += blob_loss;
    }
#endif
    break;
  default:
    LOG(FATAL) << "Unknown caffe mode.";
  }
  return loss;
}

Taking LSTMUnitLayer as an example:

template <typename Dtype>
void LSTMUnitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  ...
  LSTMActsForward<Dtype><<<CAFFE_GET_BLOCKS(X_count), CAFFE_CUDA_NUM_THREADS>>>(
      X_count, hidden_dim_, X, X_acts);
  CUDA_POST_KERNEL_CHECK;
  // NOLINT_NEXT_LINE(whitespace/operators)
  LSTMUnitForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, hidden_dim_, C_prev, X_acts, cont, C, H);
  CUDA_POST_KERNEL_CHECK;
}

The LSTMUnitBackward kernel that pairs with this (it is launched from Backward_gpu, shown in the next subsection) is:

template <typename Dtype>
__global__ void LSTMUnitBackward(const int nthreads, const int dim,
    const Dtype* C_prev, const Dtype* X, const Dtype* C, const Dtype* H,
    const Dtype* cont, const Dtype* C_diff, const Dtype* H_diff,
    Dtype* C_prev_diff, Dtype* X_diff) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;
    const Dtype* X_offset = X + 4 * dim * n;
    const Dtype i = X_offset[d];
    const Dtype f = X_offset[1 * dim + d];
    const Dtype o = X_offset[2 * dim + d];
    const Dtype g = X_offset[3 * dim + d];
    const Dtype c_prev = C_prev[index];
    const Dtype c = C[index];
    const Dtype tanh_c = tanh(c);
    Dtype* c_prev_diff = C_prev_diff + index;
    Dtype* X_diff_offset = X_diff + 4 * dim * n;
    Dtype* i_diff = X_diff_offset + d;
    Dtype* f_diff = X_diff_offset + 1 * dim + d;
    Dtype* o_diff = X_diff_offset + 2 * dim + d;
    Dtype* g_diff = X_diff_offset + 3 * dim + d;
    const Dtype c_term_diff =
        C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
    const Dtype cont_n = cont[n];
    *c_prev_diff = cont_n * c_term_diff * f;
    *i_diff = c_term_diff * g;
    *f_diff = cont_n * c_term_diff * c_prev;
    *o_diff = H_diff[index] * tanh_c;
    *g_diff = c_term_diff * i;
  }
}
2.2.5.2.2 Backward
template <typename Dtype>
inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  switch (Caffe::mode()) {
  case Caffe::CPU:
    Backward_cpu(top, propagate_down, bottom);
    break;
  case Caffe::GPU:
    Backward_gpu(top, propagate_down, bottom);
    break;
  default:
    LOG(FATAL) << "Unknown caffe mode.";
  }
}

Again taking LSTMUnitLayer as an example:

template <typename Dtype>
__global__ void LSTMActsBackward(const int nthreads, const int dim,
    const Dtype* X_acts, const Dtype* X_acts_diff, Dtype* X_diff) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int x_dim = 4 * dim;
    const int d = index % x_dim;
    const Dtype X_act = X_acts[index];
    if (d < 3 * dim) {
      X_diff[index] = X_acts_diff[index] * X_act * (Dtype(1) - X_act);
    } else {
      X_diff[index] = X_acts_diff[index] * (Dtype(1) - X_act * X_act);
    }
  }
}

template <typename Dtype>
void LSTMUnitLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
...
  LSTMUnitBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(count, hidden_dim_,
      C_prev, X_acts, C, H, cont, C_diff, H_diff, C_prev_diff, X_acts_diff);
  CUDA_POST_KERNEL_CHECK;
  const int X_count = bottom[1]->count();
  Dtype* X_diff = bottom[1]->mutable_gpu_diff();
  LSTMActsBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
      <<<CAFFE_GET_BLOCKS(X_count), CAFFE_CUDA_NUM_THREADS>>>(
      X_count, hidden_dim_, X_acts, X_acts_diff, X_diff);
  CUDA_POST_KERNEL_CHECK;
}

For the details of __global__ kernels such as LSTMUnitBackward(), consult the CUDA documentation.

3. Blob

The first time I saw Blob I misread it as the name Bob and was puzzled for a while; only reading the code made the meaning clear. A Blob stores a layer's data and provides the operations on it; its interface says it all.

template <typename Dtype>
class Blob {
 public:
  Blob() : data_(), diff_(), count_(0), capacity_(0) {}
  explicit Blob(const int num, const int channels, const int height,
      const int width);
  explicit Blob(const vector<int>& shape);  
  void Reshape(const int num, const int channels, const int height,
      const int width);
  void Reshape(const vector<int>& shape);
  void Reshape(const BlobShape& shape);
  void ReshapeLike(const Blob& other);
  inline int num_axes() const { return shape_.size(); }
  inline int count() const { return count_; }
  const Dtype* cpu_data() const;
  void set_cpu_data(Dtype* data);
  const int* gpu_shape() const;
  const Dtype* gpu_data() const;
  void set_gpu_data(Dtype* data);
  const Dtype* cpu_diff() const;
  const Dtype* gpu_diff() const;
  Dtype* mutable_cpu_data();
  Dtype* mutable_gpu_data();
  Dtype* mutable_cpu_diff();
  Dtype* mutable_gpu_diff();
  void Update();
  void FromProto(const BlobProto& proto, bool reshape = true);
  void ToProto(BlobProto* proto, bool write_diff = false) const;
  /// @brief Compute the sum of absolute values (L1 norm) of the data.
  Dtype asum_data() const;
  /// @brief Compute the sum of absolute values (L1 norm) of the diff.
  Dtype asum_diff() const;
  /// @brief Compute the sum of squares (L2 norm squared) of the data.
  Dtype sumsq_data() const;
  /// @brief Compute the sum of squares (L2 norm squared) of the diff.
  Dtype sumsq_diff() const;
  void scale_data(Dtype scale_factor);
  void scale_diff(Dtype scale_factor);
  void ShareData(const Blob& other);
  void ShareDiff(const Blob& other);
  bool ShapeEquals(const BlobProto& other);
 protected:
  shared_ptr<SyncedMemory> data_;
  shared_ptr<SyncedMemory> diff_;
  shared_ptr<SyncedMemory> shape_data_;
  vector<int> shape_;
  int count_;
  int capacity_;
  DISABLE_COPY_AND_ASSIGN(Blob);
};  // class Blob
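
A small usage sketch (a hypothetical standalone snippet, assuming the caffe headers; in real code these calls happen inside layers and solvers):

#include "caffe/blob.hpp"

void blob_demo() {
  caffe::Blob<float> blob(1, 3, 4, 5);    // shape: N x C x H x W
  float* data = blob.mutable_cpu_data();  // writable CPU memory
  for (int i = 0; i < blob.count(); ++i) {
    data[i] = 1.0f;
  }
  blob.scale_data(0.5f);                  // data *= 0.5
  // asum_data() is the L1 norm: 60 elements * 0.5 = 30
  float l1 = blob.asum_data();
  (void) l1;
}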

This wraps up the overall flow and framework of the caffe code. Upcoming posts will cover the multi-GPU parallel training code and the workflow an algorithm engineer follows when developing with caffe. Comments on omissions or mistakes are welcome, so the material can be corrected and extended.
