Embedding Variable 导出格式说明

功能介绍

EmbeddingVariable(EV)是 DeepRec 中稀疏参数的表达形式。用户训练好模型后,常常需要从 ckpt 中获取稀疏参数的信息,包括:稀疏特征(key)、特征频次以及特征版本。

EmbeddingVariable 保存到 ckpt 时以 tensor 的形式存储。假设图中 EV 的名称为 a/b/c/embedding,则它在 ckpt 中对应以下四个 tensor:

  1. "a/b/c/embedding-keys" Tensor 存储EV的key,shape=[N],dtype=int64;

  2. "a/b/c/embedding-values" Tensor 存储EV的value,shape=[N, embedding_dim], dtype=float;

  3. "a/b/c/embedding-freqs" Tensor 存储EV的频次,shape=[N],dtype=int64;

  4. "a/b/c/embedding-versions" Tensor 存储EV的最后一次更新的step数,shape=[N],dtype=int64。

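可以通过 TensorFlow 的 API(例如 tf.train.load_variable)读取 ckpt 中的值,从而获取 EV 对应的信息。以上一组 4 个 tensor 的第一维 shape 相同(均为 N),并且按顺序一一对应。对于设置了 partitioner 的 EV,每个 part 各自保存一组上述 tensor(名称形如 b/part_0-keys),part 的数目取决于所设置的 partitioner;该特征完整的稀疏参数是所有 part 的 tensor 的集合。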

代码示例

生成CKPT

import tensorflow as tf

# EV "a": a single, unpartitioned EmbeddingVariable with 4-dim embeddings.
a = tf.get_embedding_variable('a', embedding_dim=4)
# EV "b": 8-dim embeddings split into 4 parts by fixed_size_partitioner.
# Each part ("b/part_0" ... "b/part_3") is saved as its own group of
# -keys / -values / -freqs / -versions tensors.
b = tf.get_embedding_variable('b', embedding_dim=8, partitioner=tf.fixed_size_partitioner(4))

# Look up some int64 ids so that both EVs actually contain keys to save.
emb_a = tf.nn.embedding_lookup(a, tf.constant([0, 1, 2, 3, 4], dtype=tf.int64))
emb_b = tf.nn.embedding_lookup(b, tf.constant([5, 6, 7, 8, 9], dtype=tf.int64))

emb = tf.concat([emb_a, emb_b], axis=1)
loss = tf.reduce_sum(emb)

# The optimizer adds slot EVs next to each EV (e.g. "a/Adagrad", "b/part_0/Adagrad"),
# which are also written to the checkpoint.
optimizer = tf.train.AdagradOptimizer(0.1)
train_op = optimizer.minimize(loss)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([emb, train_op]))
    # Save while the session is still open: it is closed when the `with` block exits.
    saver.save(sess, "./ckpt_test/ckpt/model")

读取CKPT

import os

import tensorflow as tf
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python.framework import meta_graph

ckpt_path = 'ckpt_test/ckpt'
meta_path = os.path.join(ckpt_path, 'model.meta')
# Bind the parsed proto to a new name so the `meta_graph` module is not shadowed.
meta_graph_def = meta_graph.read_meta_graph_file(meta_path)

# Collect every EmbeddingVariable node (op type KvVarHandleOp), optimizer slots included.
ev_node_list = [node.name for node in meta_graph_def.graph_def.node
                if node.op == 'KvVarHandleOp']
print("ev node list", ev_node_list)

# filter ev-slot: slot EVs are named "<primary EV>/<slot>" (e.g. "a/Adagrad",
# "b/part_0/Adagrad"), so drop any EV whose parent scope is itself an EV.
# Unlike a substring test on "Adagrad", this does not drop a primary EV that merely
# contains "Adagrad" in its name. It should also cover other optimizers whose slots
# follow the same naming (e.g. "<EV>/Adam") — not verified here.
ev_name_set = set(ev_node_list)
non_slot_ev_list = [name for name in ev_node_list
                    if name.rpartition('/')[0] not in ev_name_set]
print("ev (exculde slot) node list", non_slot_ev_list)

# Each (partition of an) EV is stored as four checkpoint tensors sharing the same first dim.
for name in non_slot_ev_list:
    for suffix in ('-keys', '-values', '-freqs', '-versions'):
        print(name + suffix, tf.train.load_variable(ckpt_path, name + suffix))

运行结果

root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# python ckpt_test/gen_ckpt.py 
2022-04-18 08:56:37.462789: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2499445000 Hz
2022-04-18 08:56:37.468578: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fd5bb36fc60 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-04-18 08:56:37.469203: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
[array([[ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354,  1.5314401 ,
         0.05683817, -0.3698739 ,  1.7533029 , -0.95090204,  0.597882  ,
        -0.39382792,  1.2318059 ],
       [ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.83169323,
         0.15894873, -0.66453475,  0.84301287,  1.125458  ,  0.12537971,
         0.7338474 , -0.02672509],
       [ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.29247162,
        -1.1461716 ,  1.1172409 ,  1.9220417 , -1.2039331 , -1.248681  ,
        -1.9431682 , -0.27165115],
       [ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354, -0.78701115,
        -0.28825614,  1.1483766 , -0.18648145,  0.7928211 , -0.4237969 ,
        -0.00831279,  0.68605185],
       [ 0.19091757, -1.2494173 , -1.1098509 , -0.88375354,  0.5981304 ,
         0.9115529 ,  1.2290154 ,  1.329322  ,  0.81167835,  0.43949136,
         0.4266789 ,  0.5895692 ]], dtype=float32), None]

root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# ls ckpt_test/ckpt/
checkpoint  model.data-00000-of-00001  model.index  model.meta

root@i22b15440:/home/admin/chen.ding/gitlab/code/DeepRec# python ckpt_test/read_ckpt.py 
ev node list ['a', 'b/part_0', 'b/part_1', 'b/part_2', 'b/part_3', 'a/Adagrad', 'b/part_0/Adagrad', 'b/part_1/Adagrad', 'b/part_2/Adagrad', 'b/part_3/Adagrad']
ev (exculde slot) node list ['a', 'b/part_0', 'b/part_1', 'b/part_2', 'b/part_3']
a-keys [0 1 2 3 4]
a-values [[ 0.19091757 -1.2494173  -1.1098509  -0.88375354]
 [ 0.19091757 -1.2494173  -1.1098509  -0.88375354]
 [ 0.19091757 -1.2494173  -1.1098509  -0.88375354]
 [ 0.19091757 -1.2494173  -1.1098509  -0.88375354]
 [ 0.19091757 -1.2494173  -1.1098509  -0.88375354]]
a-freqs []
a-versions []
b/part_0-keys [8]
b/part_0-values [[-0.8823574  -0.3836024   1.0530304  -0.28182772  0.69747484 -0.51914316
  -0.10365905  0.5907056 ]]
b/part_0-freqs []
b/part_0-versions []
b/part_1-keys [5 9]
b/part_1-values [[ 1.4360939  -0.03850809 -0.46522018  1.6579567  -1.0462483   0.5025357
  -0.4891742   1.1364597 ]
 [ 0.50278413  0.81620663  1.1336691   1.2339758   0.7163321   0.3441451
   0.33133262  0.49422294]]
b/part_1-freqs []
b/part_1-versions []
b/part_2-keys [6]
b/part_2-values [[-0.83169323  0.15894873 -0.66453475  0.84301287  1.125458    0.12537971
   0.7338474  -0.02672509]]
b/part_2-freqs []
b/part_2-versions []
b/part_3-keys [7]
b/part_3-values [[-0.3878179  -1.2415178   1.0218947   1.8266954  -1.2992793  -1.3440272
  -2.0385144  -0.36699742]]
b/part_3-freqs []
b/part_3-versions []