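In large-scale deep learning, processing and parsing TFRecord files quickly has become a basic requirement. This post records how to quickly preview and parse a TFRecord.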
Import the required packages
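import tensorflow as tf
# Internal module path; the same class is also exposed publicly as tf.io.FixedLenFeature
from tensorflow.python.ops.parsing_ops import FixedLenFeature
# TF 1.x only: in TF 2.x eager execution is on by default and tf.logging was removed
# (tf.compat.v1.logging still exists)
tf.enable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)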
Load the TFRecord (here I load a GZIP-compressed file)
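# A single path string (a list of paths also works); compression_type must match
# how the files were written: '' (none), 'ZLIB' or 'GZIP'
filenames = 'data/20210830/part-r-00000.gz'
raw_dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')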
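For intuition about what `TFRecordDataset(..., compression_type='GZIP')` reads, here is a pure-Python sketch of the TFRecord framing, for illustration only and not a replacement for TensorFlow's reader. The file is a gzip stream of records, each laid out as a little-endian uint64 length, a masked CRC-32C of those 8 length bytes, the payload bytes, and a masked CRC-32C of the payload. The helper names `write_gzip_tfrecord` and `read_gzip_tfrecord` are hypothetical; payloads are treated as opaque bytes (in files like the one above they would be serialized `tf.train.Example` messages).

```python
import gzip
import os
import struct
import tempfile

# CRC-32C (Castagnoli) lookup table, reflected polynomial 0x82F63B78.
_POLY = 0x82F63B78
_TABLE = []
for _i in range(256):
    _c = _i
    for _ in range(8):
        _c = (_c >> 1) ^ _POLY if _c & 1 else _c >> 1
    _TABLE.append(_c)


def crc32c(data: bytes) -> int:
    """Plain CRC-32C of `data`."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores each CRC rotated right by 15 bits plus a constant."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_gzip_tfrecord(path, payloads):
    """Hypothetical helper: write each bytes payload as one record."""
    with gzip.open(path, "wb") as f:
        for data in payloads:
            header = struct.pack("<Q", len(data))  # uint64 length, little-endian
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def read_gzip_tfrecord(path):
    """Hypothetical helper: yield raw payloads, verifying both checksums."""
    with gzip.open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc(header):
                raise ValueError("corrupted length header")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if len(data) != length or data_crc != masked_crc(data):
                raise ValueError("corrupted record payload")
            yield data


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "demo.gz")
    write_gzip_tfrecord(path, [b"first record", b"second record"])
    print(list(read_gzip_tfrecord(path)))
```

The per-record checksums are why a truncated or corrupted shard fails loudly instead of yielding garbage. A file produced this way should also be readable by `tf.data.TFRecordDataset(path, compression_type='GZIP')`.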
Quickly parse and preview the contents
for raw_record in raw_dataset.take(1):
    # Each record is a serialized tf.train.Example protobuf
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
Output:
features {
  feature {
    key: "album_fea"
    value {
      float_list {
        value: 906.0
        value: 1957.0
      }
    }
  }
  feature {
    key: "albumid"
    value {
      bytes_list {
        value: "41595773"
      }
    }
  }
  feature {
    key: "is_click"
    value {
      float_list {
        value: 0.0
      }
    }
  }
}
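The dump shows, for every feature, which value list is set (`bytes_list`, `float_list` or `int64_list`) and how many values it holds, which is exactly what a `FixedLenFeature` spec needs. As a hedged sketch, assuming TensorFlow 2.x and that every record holds the same number of values per feature as the one inspected, a hypothetical helper `infer_feature_description` could derive the spec from one `tf.train.Example` instead of writing it by hand. In the preview loop you would call it on `example`.

```python
import tensorflow as tf

# Maps the oneof field set in a tf.train.Feature to the dtype parse_* expects.
_KIND_TO_DTYPE = {
    "bytes_list": tf.string,
    "float_list": tf.float32,
    "int64_list": tf.int64,
}


def infer_feature_description(example):
    """Hypothetical helper: derive a FixedLenFeature spec from one Example.

    Assumes every record holds the same number of values per feature as
    `example`; if the count varies, something like tf.io.VarLenFeature is needed.
    """
    description = {}
    for key, feature in example.features.feature.items():
        kind = feature.WhichOneof("kind")
        if kind is None:
            continue  # Feature with no value list set: nothing to infer
        num_values = len(getattr(feature, kind).value)
        description[key] = tf.io.FixedLenFeature(
            shape=[num_values], dtype=_KIND_TO_DTYPE[kind])
    return description


if __name__ == "__main__":
    # An Example rebuilt by hand from the printed dump.
    example = tf.train.Example(features=tf.train.Features(feature={
        "album_fea": tf.train.Feature(
            float_list=tf.train.FloatList(value=[906.0, 1957.0])),
        "albumid": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b"41595773"])),
        "is_click": tf.train.Feature(
            float_list=tf.train.FloatList(value=[0.0])),
    }))
    for name, spec in sorted(infer_feature_description(example).items()):
        print(name, spec)
```

Inferring from a single record is only a starting point: it is worth checking a few more records before trusting the shapes.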
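Printing the Example only shows its contents as text. To parse records into tensors that TensorFlow can operate on, define a feature_description whose dtypes and shapes match the printed features: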
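# One entry per feature: shape = number of values per record, dtype matches the list type
# (bytes_list -> tf.string, float_list -> tf.float32, int64_list -> tf.int64)
feature_description = {
    'albumid': FixedLenFeature(shape=[1], dtype=tf.string),
    'album_fea': FixedLenFeature(shape=[2], dtype=tf.float32),
    'is_click': FixedLenFeature(shape=[1], dtype=tf.float32)
}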
for serialized_example in raw_dataset.take(1):
    # Returns a dict mapping feature name -> Tensor
    features = tf.io.parse_single_example(serialized_example, feature_description)
    print(features)
Output:
{'album_fea': <tf.Tensor: id=748, shape=(2,), dtype=float32, numpy=array([ 906., 1957.], dtype=float32)>,
'albumid': <tf.Tensor: id=749, shape=(1,), dtype=string, numpy=array([b'41595773'], dtype=object)>,
'is_click': <tf.Tensor: id=750, shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>
}
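For training you would usually parse the whole dataset rather than one record. Below is a minimal end-to-end sketch, assuming TensorFlow 2.4+ (eager by default, `tf.data.AUTOTUNE` available). It writes a small GZIP-compressed file with `tf.io.TFRecordWriter`, reads it back with `TFRecordDataset`, batches the serialized records and parses each batch with `tf.io.parse_example` using the same feature_description. The helpers `make_example`, `write_sample_file` and `load_dataset` are hypothetical, and the second record's values are dummy placeholders.

```python
import os
import tempfile

import tensorflow as tf

# Same spec as the feature_description defined earlier.
feature_description = {
    "albumid": tf.io.FixedLenFeature(shape=[1], dtype=tf.string),
    "album_fea": tf.io.FixedLenFeature(shape=[2], dtype=tf.float32),
    "is_click": tf.io.FixedLenFeature(shape=[1], dtype=tf.float32),
}


def make_example(albumid, album_fea, is_click):
    """Hypothetical helper: build a tf.train.Example with the same layout."""
    return tf.train.Example(features=tf.train.Features(feature={
        "albumid": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[albumid])),
        "album_fea": tf.train.Feature(
            float_list=tf.train.FloatList(value=album_fea)),
        "is_click": tf.train.Feature(
            float_list=tf.train.FloatList(value=[is_click])),
    }))


def write_sample_file(path, examples):
    """Write serialized Examples to a GZIP-compressed TFRecord file."""
    options = tf.io.TFRecordOptions(compression_type="GZIP")
    with tf.io.TFRecordWriter(path, options=options) as writer:
        for ex in examples:
            writer.write(ex.SerializeToString())


def load_dataset(path, batch_size):
    """Batch serialized records first, then parse a whole batch per call."""
    ds = tf.data.TFRecordDataset(path, compression_type="GZIP")
    ds = ds.batch(batch_size)
    ds = ds.map(
        lambda serialized: tf.io.parse_example(serialized, feature_description),
        num_parallel_calls=tf.data.AUTOTUNE)
    return ds


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "sample.gz")
    write_sample_file(path, [
        make_example(b"41595773", [906.0, 1957.0], 0.0),
        make_example(b"dummy", [1.0, 2.0], 1.0),  # dummy record
    ])
    for batch in load_dataset(path, batch_size=2):
        print(batch["album_fea"].numpy())
```

Batching before `map` lets `tf.io.parse_example` handle many records per call, which is typically cheaper than mapping `tf.io.parse_single_example` over individual records; each parsed tensor gains a leading batch dimension (for example `album_fea` becomes `[batch_size, 2]`).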