获取数据预处理函数。
1 | def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs) |
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
tgt_key_specs |
|
必选,可使用两种方式传入,具体请参见说明。 |
特征对象或特征对象列表或者特征对象元组,默认值为“None”。 |
args_index_list |
list[int] |
参数索引列表,默认值为“None”。取值范围:[1, 2^31-1] |
|
table_names |
list[str] |
表名称列表,默认值为“None”。取值范围:[1, 2^31-1] |
接口参数可选择以下其中一种方式传入。
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
is_training |
bool |
可选 |
是否为训练模式,默认值为“True”。 取值范围:
|
dump_graph |
bool |
可选 |
是否保存模型图,默认值为“False”。 取值范围:
|
1 2 3 4 5 | import tensorflow as tf from mx_rec.core.asc.helper import get_asc_insert_func dataset = tf.data.TFRecordDataset(data_path) # data_path为数据集路径 dataset = dataset.map(get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=True)) # feature_spec_list中元素为FeatureSpec对象 |
接口调用流程及示例,参见模型迁移与训练。