数据预处理中存在tf.Variable需要手工修改

迁移原因

数据预处理中存在tf.Variable时,tf.Variable在Host侧执行,而变量初始化图默认在Device侧执行。变量和变量初始化图不在同一设备执行,会导致训练异常,需要手工迁移。

    batch_size = tf.Variable(
         tf.placeholder(tf.int64, [], 'batch_size'),
         trainable= False, collections=[]
    )
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)

迁移示例

将tf.Variable修改成常量,即可解决问题:

    batch_size = 64
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)