sparse_lookup
功能描述
Rec SDK模型训练框架,稀疏特征表查询接口。
当前仅支持一表一查和一表多查。若存在一表多查的情况下,查询次数最大值为128。
暂不支持tf.SparseTensor数据类型,若是tf.SparseTensor需转成tf.Tensor。示例代码如下:
# 示例代码 sparse_ids = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) dense_ids = tf.sparse.to_dense(sparse_ids, default_value=0) embedding = sparse_lookup(sparse_hashtable, dense_ids)
函数原型
def sparse_lookup(hashtable, ids, send_count, is_train=True, name=None, modify_graph=False, batch=None, access_and_evict_config=None, is_grad=True, serving_default_value, **kwargs)
参数说明
参数名  | 
类型  | 
可选/必选  | 
说明  | 
|---|---|---|---|
hashtable  | 
SparseEmbedding  | 
必选  | 
待查询的稀疏表。  | 
ids  | 
FeatureSpec/tf.Tensor  | 
必选  | 
查询的关键字(key),对应参数类型在不同功能模式下存在区别,具体参见如下。 
  | 
send_count  | 
int  | 
可选参数;开启静态shape时为必选参数。  | 
作为All2All通信技术,取值范围:[1, 2147483647]。 开启动态shape时无需传该参数,或传“None”即可。默认值为“None”。  | 
is_train  | 
bool  | 
必选  | 
是否为训练模式。默认值为“True”。 取值范围: 
  | 
name  | 
str  | 
可选  | 
为该次查询操作创建对应的名称,字符串长度为[1,255]。默认值为“None”。  | 
modify_graph  | 
bool  | 
可选  | 
自动改图功能开关,该功能将在创建Session实例前对模型原图进行修改优化,默认值为“False”。 取值范围: 
  | 
batch  | 
dict  | 
可选  | 
数据集的迭代器。 当同时使用FeatureSpec类型、动态Shape功能时,“batch”参数必须传入。默认值为“None”。  | 
access_and_evict_config  | 
dict  | 
可选  | 
自动改图模式下开启特征准入与淘汰时使用。该dict由两个key-value对组成,“key”分别为“access_threshold”和“eviction_threshold”,“value”为对应的阈值。默认值为“None”。  | 
is_grad  | 
bool  | 
可选  | 
此次查询是否需要梯度更新,默认值为“True”。 取值范围: 
  | 
serving_default_value  | 
tf.Tensor  | 
可选  | 
训练时未准入特征/预测时的新特征的默认emb值。如果不指定,默认为“0”。  | 
**kwargs参数说明
参数名  | 
类型  | 
可选/必选  | 
说明  | 
|---|---|---|---|
feature_spec_name_ids_dict  | 
dict  | 
可选  | 
字典结构,key为FeatureSpec名称,value为公开接口sparse_lookp()的参数ids,无默认值。  | 
multi_lookup  | 
bool  | 
可选  | 
是否存在一表多查的情况,无默认值。 取值范围: 
  | 
lookup_ids  | 
FeatureSpec/tf.Tensor  | 
可选  | 
查询的关键字(key),对应参数类型在不同功能模式下存在区别,具体参见如下。无默认值。 
  | 
 - **kwargs参数中的“feature_spec_name_ids_dict”、“multi_lookup”和“lookup_ids”作为内部使用参数,不建议用户通过kwargs传递这三个参数。
 - 如果通过kwargs传递其他未说明参数,则Rec SDK内部不会使用到该参数。
 
返回值说明
- 成功:返回查询到的Tensor类结果。
 - 失败:抛出异常。
 
使用示例
from mx_rec.core.embedding import sparse_lookup
from mx_rec.core.asc.feature_spec import FeatureSpec
feature_spec = FeatureSpec("sparse_feature", table_name="sparse_embeddings_table",
                                batch_size=1)
embedding = sparse_lookup(sparse_hashtable,
                          feature_spec,
                          send_count=6000,
                          is_train=True,
                          name="sparse_embeddings")
参考资源
接口调用流程及示例,参见模型迁移与训练。