侧边栏壁纸
博主头像
无知无畏 博主等级

行动起来,活在当下

  • 累计撰写 13 篇文章
  • 累计创建 6 个标签
  • 累计收到 0 条评论

目 录CONTENT

文章目录

tensorflow 1.12 如何实现增量更新 embedding

快乐玩耍
2024-11-03 / 0 评论 / 0 点赞 / 1 阅读 / 0 字

Tensorflow 1.12 是静态图,并且平时使用 tensorflow 原生 API 的时候,习惯将 ID-Index- Embedding 的过程建立到图里。当需要对保存到图里的这个映射关系进行更新,也就是实现增量更新时,我们通常会感到苦恼。
下面是增量更新 ID-Index-Embedding 的过程:

# TensorFlow 1.x 示例

import tensorflow as tf

# 假设你有meta文件和数据文件路径
meta_file = 'path/to/model.ckpt.meta'
checkpoint_path = 'path/to/model.ckpt'

# 创建一个新的会话
with tf.Session() as sess:
    # 导入元图以恢复图结构和变量
    saver = tf.train.import_meta_graph(meta_file)
    saver.restore(sess, checkpoint_path)

    # 获取图中已经存在的MutableHashTable
    hash_table = sess.graph.get_tensor_by_name('your_hash_table_name:0')  # 替换'your_hash_table_name'为实际名称

    # 假设有新的键和值列表
    new_keys = [...]
    new_values = [...]

    # 更新哈希表内容(通过会话运行insert操作)
    for key, value in zip(new_keys, new_values):
        insert_op = hash_table.insert(keys=tf.constant(key), values=tf.constant(value))
        sess.run(insert_op)

# 定义一个saver来保存所有的变量(包括哈希表中的内容)
saver = tf.train.Saver()

# 选择一个新的保存路径
new_checkpoint_path = 'path/to/new_model.ckpt'

# 保存模型
saver.save(sess, new_checkpoint_path)

通常来说,增量更新是通过加载已有模型,然后在已有模型的 HashTable 中插入新的键值对,从而实现增量。

2024.4.4 更新:实际上上面的这个方式不起作用,因为tensorflow在保存成savedModel的时候,会把MutableHashTable当成Operation,而不是当成一个tensor,并且不可变。

0

评论区