tf.gather
tf.gather可以实现根据索引号收集数据的目的。考虑班级成绩册的例子,假设共有4个班级,每个班级35个学生,8门科目,保存成绩册的张量shape为[4,35,8]
x=tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32) # 成绩册张量
现在需要收集第1~2个班级的成绩册,可以给定需要收集班级的索引号:[0,1],并指定班级的维度axis=0,通过tf.gather函数收集数据,带码如下:
tf.gather(x,[0,1],axis=0) # 在班级维度收集1~2班级成绩册
实际上,对于上述需求,通过切片x[:2]可以更加方便的实现。但是对于不规则的索引方式,比如,需要抽查所有班级的第1、4、9、12、13、27号学生的成绩数据,则切片方式实现起来非常麻烦,而tf.gather则是针对于此需求设计的,使用起来更加方便,实现如下:
tf.gather(x, [0,3,8,11,12,26],axis=1)
如果需要收集所有同学的第3和第5门科目的成绩,则可以指定科目维度axis=2,实现如下:
tf.gather(x,[2,4],axis=2)
可以看到,tf.gather非常适合索引号没有规则的场合,其中索引号可以乱序排序,此时收集的数据也是对应顺序,例如:
a = tf.range(8)
a = tf.reshape(a, [4,2])
print(a)
print(tf.gather(a, [3,1,0,2], axis=0))
我们将问题变得稍微复杂一点。如果希望抽查第[2,3]班级的第[3,4,6,27]号同学的科目成绩,则可以通过组合多个tf.gather实现。首先抽出第[2,3]班级,实现如下:
student = tf.gather(x, [1,2], axis=0)
再从这2个班级的同学中提取对应学生成绩,代码如下:
tf.gather(student, [2,3,5,26],axis=1)
此时得到这2个班级4个学生的成绩张量,shape为[2,4,8]
tf.gather_nd
通过tf.gather_nd函数,可以通过指定每次采样点的多维坐标实现采样多个点的目的。抽查第2个班级的第2个同学的所有科目,第3个班级的第3个同学的所有科目,第4个班级的第4个同学的所有科目。那么这3个采样点的索引坐标可以记为:[1,1][2,2][3,3],我们将这个采样方案合并一个List参数,即[[1,1][2,2][3,3]],通过tf.gather_nd函数即可,实现如下:
tf.gather_nd(x, [[1,1],[2,2],[3,3]])
可以看到,结果与串行采样方式的完全一样,实现更加简洁,计算效率大大提升。