A TensorFlow training task usually consists of two parts: reading sample data and computing the graph. Reading sample data is an IO-bound operation. It can take a large share of the end-to-end (E2E) time, which slows down training and keeps the computing resources (CPU, GPU) from being used efficiently. DeepRec provides a stage feature for this, based on the idea of TensorFlow's StagingArea, and exposes it through the API tf.staged. The user explicitly specifies which part of the graph should be staged and adds tf.make_prefetch_hook() when creating the Session. The TensorFlow runtime then executes the staged part asynchronously, overlapping it with the rest of the graph, which improves the execution efficiency of the whole graph.
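The overlap that staging provides can be illustrated without TensorFlow. The sketch below is a plain-Python analogy, not DeepRec's implementation: a background thread fills a bounded buffer (playing the role of the stage's cache) with IO-bound samples while the consumer computes on earlier ones. The names prefetch, read_samples and train are hypothetical, and time.sleep stands in for real IO and computation.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the input stream


def prefetch(produce_items, capacity=1):
    """Yield items from produce_items() while a background thread reads ahead.

    At most `capacity` produced items wait in the buffer at any time.
    """
    buffer = queue.Queue(maxsize=capacity)

    def producer():
        for item in produce_items():
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()  # blocks only if the producer has fallen behind
        if item is _END:
            return
        yield item


def read_samples(n, io_ms=10):
    """Stand-in for IO-bound sample reading."""
    for i in range(n):
        time.sleep(io_ms / 1000)
        yield i


def train(samples, compute_ms=10):
    """Stand-in for graph computation; returns the samples it consumed."""
    consumed = []
    for sample in samples:
        time.sleep(compute_ms / 1000)
        consumed.append(sample)
    return consumed


if __name__ == "__main__":
    start = time.perf_counter()
    train(read_samples(20))
    sequential = time.perf_counter() - start

    start = time.perf_counter()
    train(prefetch(lambda: read_samples(20)))
    overlapped = time.perf_counter() - start
    print(f"sequential {sequential:.2f}s, with prefetch {overlapped:.2f}s")
```

With equal IO and compute times, the prefetching run should take roughly half as long as the sequential one, because reading the next sample overlaps with computing on the current one.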


tf.staged prefetches the input features and returns the prefetched tensors. Its parameters, with default values where given:

  • The tensors to execute asynchronously: a tensor, a list of tensors (each element is a tensor), or a dict of tensors (all keys are strings and all values are tensors).

  • A dictionary that maps graph elements to tensors.

  • The maximum number of asynchronous execution results to cache.

  • The number of threads that execute items asynchronously.

  • The number of clients of the prefetched samples.

  • The maximum time, in milliseconds, that the put op can take. Default: 300000 ms.

  • Exception types recognized as graceful exits.

  • Exception types that are ignored and skipped.

  • Whether to run the stage subgraph on an independent thread pool. Default: False (if True, an independent thread pool must be created first).

  • The index of the independent thread pool that runs the stage subgraph. It applies only when the use_stage_subgraph_thread_pool option is enabled, and the independent thread pool must be created first. Default: 0; valid range: [0, number of independent thread pools created - 1].

  • In the GPU Multi-Stream scenario, the index of the GPU stream used by the stage subgraph. Default: 0, meaning the stage subgraph shares the GPU stream used by the main graph; valid range: [0, total number of GPU streams - 1].

  • The name of the prefetching operations. Default: None (generated automatically).

Add the tf.make_prefetch_hook() hook when creating the session (sess_config is the session's tf.ConfigProto):

with tf.train.MonitoredTrainingSession(hooks=[tf.make_prefetch_hook()], config=sess_config) as sess:
  ...  # training loop, e.g. sess.run(train_op)
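The hook matters because something has to start the background work that fills the stage. Assuming, as with TensorFlow queue runners, that tf.make_prefetch_hook() is what starts those threads once the session exists, the plain-Python sketch below shows what happens otherwise: the consumer waits on a buffer that nothing fills. PrefetchRunner is a hypothetical class, and a timeout is used here only so the demonstration returns instead of blocking.

```python
import queue
import threading


class PrefetchRunner:
    """Hypothetical analogue of a stage: the buffer stays empty until start() is called."""

    def __init__(self, produce_items, capacity=1):
        self._produce_items = produce_items
        self.buffer = queue.Queue(maxsize=capacity)

    def start(self):
        """Launch the background producer (the role the hook plays in this analogy)."""
        def loop():
            for item in self._produce_items():
                self.buffer.put(item)
        threading.Thread(target=loop, daemon=True).start()

    def get(self, timeout=None):
        """Block until an item is available; with timeout=None this may wait forever."""
        return self.buffer.get(timeout=timeout)


if __name__ == "__main__":
    runner = PrefetchRunner(lambda: range(3))
    try:
        runner.get(timeout=0.5)
    except queue.Empty:
        print("no producer started: an untimed get would hang here")
    runner.start()
    print([runner.get() for _ in range(3)])  # [0, 1, 2]
```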
  • Create a separate thread pool (optional)

sess_config = tf.ConfigProto()
sess_config.session_stage_subgraph_thread_pool.add()  # add a thread pool
sess_config.session_stage_subgraph_thread_pool[0].inter_op_threads_num = 8  # number of inter-op threads in the pool
sess_config.session_stage_subgraph_thread_pool[0].intra_op_threads_num = 8  # number of intra-op threads in the pool
sess_config.session_stage_subgraph_thread_pool[0].global_name = "StageThreadPool_1"  # name of the pool
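Running stage work on its own pool keeps it from competing with the main graph's threads. The sketch below is only a plain-Python analogy of that isolation, using concurrent.futures.ThreadPoolExecutor; it does not configure DeepRec. max_workers=8 loosely mirrors the thread counts and thread_name_prefix mirrors global_name in the configuration above, while preprocess is a hypothetical stand-in for the staged subgraph.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def preprocess(x):
    """Hypothetical stage work; also reports which thread ran it."""
    return x * 2, threading.current_thread().name


if __name__ == "__main__":
    # A pool dedicated to stage work, separate from the thread doing "training".
    with ThreadPoolExecutor(max_workers=8, thread_name_prefix="StageThreadPool_1") as stage_pool:
        futures = [stage_pool.submit(preprocess, i) for i in range(4)]
        results = [f.result() for f in futures]

    print([value for value, _ in results])  # [0, 2, 4, 6]
    # Every worker name starts with "StageThreadPool_1_"; none is the main thread.
    print(sorted({name for _, name in results}))
    print(threading.current_thread().name)  # MainThread
```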
  • GPU Multi-Stream Stage (optional)

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

sess_config = tf.ConfigProto()
sess_config.graph_options.rewrite_options.use_multi_stream = rewriter_config_pb2.RewriterConfig.ON  # enable GPU Multi-Stream
sess_config.graph_options.rewrite_options.multi_stream_opts.multi_stream_num = 2  # number of GPU streams; stream 0 is used by the main graph
sess_config.graph_options.optimizer_options.stage_multi_stream = True  # enable GPU Multi-Stream Stage
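The index rules for the thread pool and the GPU stream can be summarized as a small check. check_stage_placement is a hypothetical helper, not part of DeepRec; it only restates the ranges given in the parameter descriptions: a pool index in [0, pools - 1], valid only once an independent pool exists, and a stream index in [0, streams - 1], where 0 means the main graph's stream.

```python
def check_stage_placement(use_pool, pool_index, num_pools, stream_index, num_streams):
    """Validate stage placement settings and describe where the stage subgraph runs."""
    if use_pool:
        if num_pools < 1:
            raise ValueError("create an independent thread pool first")
        if not 0 <= pool_index < num_pools:
            raise ValueError(f"pool index must be in [0, {num_pools - 1}]")
    if not 0 <= stream_index < num_streams:
        raise ValueError(f"stream index must be in [0, {num_streams - 1}]")
    pool = f"independent thread pool {pool_index}" if use_pool else "default thread pool"
    stream = "stream shared with the main graph" if stream_index == 0 else f"GPU stream {stream_index}"
    return f"{pool}, {stream}"


if __name__ == "__main__":
    # One independent pool and multi_stream_num = 2, as in the configurations above.
    print(check_stage_placement(True, 0, 1, 1, 2))
    # independent thread pool 0, GPU stream 1
```

With multi_stream_num = 2, stream 1 is the only stream the stage subgraph can use without sharing the main graph's stream 0.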

GPU Multi-Stream Stage can achieve better performance when GPU MPS (Multi-Process Service) is enabled; please refer to GPU-MultiStream.


  • The computation to be made asynchronous should compete as little as possible with the subsequent main computation for resources (GPU, CPU, thread pools, etc.).

  • A larger capacity consumes more memory or GPU memory and may take CPU resources away from subsequent model training. It is recommended to set it to (follow-up computation time) / (asynchronous wait time), starting from 1 and increasing gradually.

  • A larger num_threads is not always better. It only needs to be large enough for preprocessing to overlap with computation; more threads preempt CPU resources needed for model training. Rule of thumb: num_threads >= preprocessing time / training time, starting from 1 and increasing gradually.

  • tf.make_prefetch_hook() must be added; otherwise the program will hang.
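The num_threads rule of thumb can be written down directly. The sketch below uses a hypothetical, idealized model that assumes perfect overlap and no resource contention: once the stage buffer is full, each step takes as long as the slower of training and (preprocessing time / num_threads). Under that model, threads beyond the ratio stop helping.

```python
import math


def suggest_num_threads(preprocess_ms, train_ms):
    """Smallest integer satisfying num_threads >= preprocess_ms / train_ms (at least 1)."""
    return max(1, math.ceil(preprocess_ms / train_ms))


def steady_state_step_ms(preprocess_ms, train_ms, num_threads):
    """Idealized per-step time once the stage buffer is full."""
    return max(train_ms, preprocess_ms / num_threads)


if __name__ == "__main__":
    pre_ms, train_ms = 30.0, 10.0  # hypothetical timings
    print(suggest_num_threads(pre_ms, train_ms))  # 3
    for n in (1, 2, 3, 6):
        print(n, steady_state_step_ms(pre_ms, train_ms, n))
    # prints, one per line: 1 30.0, 2 15.0, 3 10.0, 6 10.0
```

Real step times also depend on contention for CPU and memory, which is why the notes recommend starting from 1 and tuning upward rather than trusting the ratio alone.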


import tensorflow as tf

filename_queue = tf.train.string_input_producer(['1.txt'])
reader = tf.TextLineReader()
k, v = reader.read(filename_queue)

var = tf.get_variable("var", shape=[100, 3], initializer=tf.ones_initializer())
v = tf.train.batch([v], batch_size=2, capacity=20 * 3)
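v0, v1 = tf.decode_csv(v, record_defaults=[[0], ['']], field_delim=',')  # column 0: int id in [0, 100), column 1: string
xx = tf.staged([v0, v1])  # prefetch the parsed batch asynchronously

xx[0] = tf.nn.embedding_lookup(var, xx[0])  # ids -> embeddings, shape [2, 3]
xx[1] = tf.concat([xx[1], ['xxx']], axis=0)  # shape [3]
target = tf.concat([tf.as_string(xx[0]), [xx[1], xx[1]]], 0)  # shape [4, 3]

# target is the node fetched at every step

with tf.train.MonitoredTrainingSession(hooks=[tf.make_prefetch_hook()]) as sess:
  for i in range(5):
    print(sess.run(target))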
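The example reads 1.txt, whose contents are not shown. Its first column feeds tf.nn.embedding_lookup on var (shape [100, 3]), so it should hold integer ids in [0, 100); the second column is a free-form string. A minimal sketch that writes such a file, with hypothetical contents; write_sample_file and its parameters are illustrative names:

```python
import random


def write_sample_file(path="1.txt", num_lines=20, vocab_size=100, seed=0):
    """Write hypothetical 'id,label' lines and return them.

    ids are drawn from [0, vocab_size) so they are valid row indices of var.
    """
    rng = random.Random(seed)
    lines = [f"{rng.randrange(vocab_size)},label{i}" for i in range(num_lines)]
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return lines


if __name__ == "__main__":
    write_sample_file()
```

Because string_input_producer cycles through its files indefinitely by default, even a short file is enough for the five steps in the example.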