在TensorFlow的使用过程中,我们常常希望得到一个tensor的维度信息使用,具体的说,也就是现在有了一个tensor值,如何才能得到其shape信息,也就是维度值作为一个整数值使用呢?
对于一个tensor值,我们很容易利用tensor.get_shape() 或 tf.shape(tensor)来获取其shape。但是,这两种方法返回的shape信息都是Dimension 类型的,并非int32类型的。下面两种方法可以获得tensor shape的具体值。
- 方法一:利用as_list()方法
利用tensor.get_shape().as_list() 方法。对于一个2-D的tensor,获得其行列值可以这样做,
num_rows, num_cols = X.get_shape().as_list()
- 方法二:利用Dimension对象的value属性
利用tensor.get_shape()[0].value 方法。对于一个2-D的tensor,获得其行列值可以这样做,
num_rows, num_cols = map(lambda i: i.value, X.get_shape())