Python TensorFlow: how to save/restore a model?

Warning: these are provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must attribute it to the original authors (not me): Stack Overflow


Tensorflow: how to save/restore a model?

python, tensorflow, machine-learning, model

Asked by mathetes

After you train a model in Tensorflow:


  1. How do you save the trained model?
  2. How do you later restore this saved model?

Accepted answer by ted

Docs


From the docs:


Save


# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Restore


tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Tensorflow 2


This is still in beta so I'd advise against it for now. If you still want to go down that road, here is the tf.saved_model usage guide.

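For reference, here is a minimal TF2-style sketch of that workflow using the SavedModel format; the toy Keras model and the /tmp path are illustrative assumptions, not part of the original answer:

import tensorflow as tf  # assumes TensorFlow 2.x

# Build (and normally train) some model; a toy Keras model for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(30,))])
model.compile(optimizer="adam", loss="mse")

# Save the whole model (graph + weights) in the SavedModel format.
tf.saved_model.save(model, "/tmp/tf2_model")

# Restore it later; the returned object exposes the saved signatures.
restored = tf.saved_model.load("/tmp/tf2_model")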

Tensorflow < 2


simple_save


Many good answers already; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.


Python 3; TensorFlow 1.14


import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Restoring:


graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Standalone example


Original blog post


The following code generates random data for the sake of the demonstration.


  1. We start by creating the placeholders. They will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.
  2. The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
  3. The loss is softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, the model will be saved in a folder called simple/ in your current working directory.
  4. In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
  5. Lastly we run an inference for both batches in the dataset, and check that the saved and restored model both yield the same values. They do!

Code:


import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

This will print:


$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

Answered by Ryan Sepassi

For TensorFlow version < 0.11.0RC1:


The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.


Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that will restore variables saved in a prior run and compute predictions. Of course, you can also restore variables and continue training if you'd like.


x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Here are the docs for Variables, which cover saving and restoring. And here are the docs for the Saver.


Answered by Yaroslav Bulatov

There are two parts to the model: the model definition, saved by Supervisor as graph.pbtxt in the model directory, and the numerical values of tensors, saved into checkpoint files like model.ckpt-1003418.


The model definition can be restored using tf.import_graph_def, and the weights are restored using Saver.

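As a rough sketch of that two-part split (the file name graph.pbtxt and the use of text_format here are assumptions for illustration, since Supervisor writes the graph as a text-format protobuf):

import tensorflow as tf
from google.protobuf import text_format

# Re-import the model definition from the text-format GraphDef.
graph_def = tf.GraphDef()
with open('graph.pbtxt') as f:
    text_format.Parse(f.read(), graph_def)

with tf.Graph().as_default():
    tf.import_graph_def(graph_def, name='')
    # The variables' numerical values still have to come from the checkpoint
    # via a Saver, which runs into the limitation described below.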

However, Saver uses a special collection holding the list of variables that's attached to the model Graph, and this collection is not initialized using import_graph_def, so you can't use the two together at the moment (it's on our roadmap to fix). For now, you have to use Ryan Sepassi's approach -- manually construct a graph with identical node names, and use Saver to load the weights into it.


(Alternatively you could hack it by using import_graph_def, creating variables manually, and using tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using Saver.)


Answered by nikitakit

As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.


I implemented this for my personal use, so I thought I'd share the code here.


Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868


(This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)


Answered by Yuan Tang

You can also check out the examples in TensorFlow/skflow, which offers save and restore methods that can help you easily manage your models. It also has a parameter that lets you control how frequently you want to back up your model.


Answered by Sergey Demyanov

If it is an internally saved model, you just specify a restorer for all variables as


restorer = tf.train.Saver(tf.all_variables())

and use it to restore variables in a current session:


restorer.restore(self._sess, model_file)

For an external model you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command


python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

The inspect_checkpoint.py script can be found in './tensorflow/python/tools' folder of the Tensorflow source.


To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here

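For reference, the same two steps can also be sketched directly in Python; the checkpoint path and variable names below are made up for illustration. tf.train.NewCheckpointReader lists what is stored in a checkpoint, and a tf.train.Saver built from a dict maps stored names to your own variables:

import tensorflow as tf

ckpt_path = '/path/to/pretrained_model/model.ckpt'  # illustrative path

# List the variable names and shapes stored in the checkpoint.
reader = tf.train.NewCheckpointReader(ckpt_path)
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)

# Map a name stored in the checkpoint to one of your own variables
# and restore only that mapping ('conv1/weights' is an assumed name).
my_w = tf.get_variable('my_scope/weights', shape=[3, 3, 64, 64])
restorer = tf.train.Saver({'conv1/weights': my_w})
with tf.Session() as sess:
    restorer.restore(sess, ckpt_path)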

Answered by lei du

In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, as described in https://www.tensorflow.org/programmers_guide/meta_graph.


Save the model


w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files: my-model.meta

Restore the model


sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

Answered by AI4U.ai

As described in issue 6255:


use './model_name.ckpt'
saver.restore(sess, './my_model_final.ckpt')

instead of


saver.restore('my_model_final.ckpt')

Answered by Himanshu Babal

You can also take this easier way.


Step 1: initialize all your variables


W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Step 2: save the session inside a model Saver and save it


model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Step 3: restore the model


with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Step 4: check your variable


W1 = session.run(W1)
print(W1)


While running in a different Python instance, use


with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # No need to re-run the variable initializer here:
    # saver.restore already initializes the restored variables

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = sess.run(W1)

Answered by MiniQuark

In most cases, saving and restoring from disk using a tf.train.Saver is your best option:


... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can then use to restore the model's state:


saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save checkpoints every time the model improves during training (as measured on the validation set), then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:


... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.


Answered by Martin Pecka

Here's my simple solution for the two basic cases, which differ on whether you want to load the graph from a file or build it at runtime.


This answer holds for Tensorflow 0.12+ (including 1.0).


Rebuilding the graph in code


Saving


graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading


graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Loading also the graph from a file


When using this technique, make sure all your layers/variables have explicitly set unique names. Otherwise TensorFlow will make the names unique itself, and they'll thus be different from the names stored in the file. It's not a problem with the previous technique, because the names are "mangled" the same way in both loading and saving.

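A minimal sketch of what explicitly setting unique names looks like in practice (the names below are just examples, not taken from the answer):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10], name='input_x')     # explicit placeholder name
w = tf.Variable(tf.zeros([10, 5]), name='classifier_weights')  # explicit variable name
logits = tf.layers.dense(x, 5, name='classifier_logits')       # explicit layer name
# After restoring, these can then be looked up reliably, e.g.
# graph.get_tensor_by_name('input_x:0')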

Saving


graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading


with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection