argmax的解读
今天我们来看argmax这个op,经常写python的朋友肯定会用到这么一个函数,他是用来返回矩阵最大值的索引的。下面代码通过ipython执行:
In [1]: import numpy as np
In [2]: a = np.arange(6).reshape(2,3) + 10
In [3]: a
Out[3]:
array([[10, 11, 12],
[13, 14, 15]])
In [4]: np.argmax(a)
Out[4]: 5
In [5]: np.argmax(a, axis = 0)
Out[5]: array([1, 1, 1])
In [6]: np.argmax(a, axis = 1)
Out[6]: array([2, 2])
现在我们可以来看ncnn的argmax.cpp和argmax.h的代码了:
先看构造函数:
ArgMax::ArgMax()
{
one_blob_only = true;
}
这里只有一个one_blob_only是true的,说明该函数是单输入,单输出的。
接下来看参数装载函数:
int ArgMax::load_param(const ParamDict& pd)
{
out_max_val = pd.get(0, 0);
topk = pd.get(1, 1);
return 0;
}
该函数内可以看见两个参数out_max_val和topk,暂时通过字面意思不太能分析出其中的含义,我们接下来看推理代码:
int ArgMax::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
//底层blob的所有元素个数
int size = bottom_blob.total();
//由分配器给矩阵分配内存
if (out_max_val)
top_blob.create(topk, 2, 4u, opt.blob_allocator);
else
top_blob.create(topk, 1, 4u, opt.blob_allocator);
if (top_blob.empty())
return -100;
const float* ptr = bottom_blob;
// partial sort topk with index
// optional value
std::vector<std::pair<float, int> > vec;
vec.resize(size);
for (int i = 0; i < size; i++)
{
vec[i] = std::make_pair(ptr[i], i);
}
std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
std::greater<std::pair<float, int> >());
//上面这部分代码构造了一个值和索引的向量,且进行部分降序排序,取前topk个值
float* outptr = top_blob;
if (out_max_val)
{
float* valptr = outptr + topk;
for (int i = 0; i < topk; i++)
{
outptr[i] = vec[i].first;
valptr[i] = vec[i].second;
}
}
else
{
for (int i = 0; i < topk; i++)
{
outptr[i] = vec[i].second;
}
}
return 0;
}
现在来看这段代码已经很清晰了,out_max_val这个参数是是否输出值,因为argmax本身是只输出索引的,这里为true可以输出值。topk表示输出为前topk个最大的值,这里topk这个名字一开始有点懵逼,现在看来非常贴切。
上面这段代码虽然少,但信息量不小,可以看出ncnn在进行运算过程中数据是由bottom向top传播的。
最下面那里感觉像是一个小错误,outptr应该是输出的索引值,valptr应该是输出的数据值,这里明天和nihui确认一下。
先把这个pr上去:
argmax
y = argmax(x, out_max_val, topk)
- one_blob_only
param id | name | type | default |
---|---|---|---|
0 | out_max_val | int | 0 |
1 | topk | int | 1 |