Distributed TensorFlow

Updated at 2018-04-30 15:15

Distributed TensorFlow allows you to share parts of a TensorFlow graph between multiple processes, even on different machines.

It can increase iteration speed by scaling training up to hundreds of CPUs and GPUs.

A Distributed TensorFlow system has three types of processes:

  • master worker / mw (also called the chief)
  • parameter servers / ps
  • workers / w

What they do:

  • master worker initializes the model, coordinates the training operations and handles fault tolerance of parameter servers and workers.
  • parameter servers manage the shared model state that all workers read and update, e.g. variables and the update operations applied to them.
  • workers do all the compute-intensive parts of the process, e.g. preprocessing, loss calculation and backpropagation.

You need more parameter servers to handle a large volume of worker I/O; a single parameter server won't be able to keep up with requests from 50 workers.

The same code is often sent to all nodes, and environment variables decide which code blocks each node executes.
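A minimal sketch of that pattern in plain Python, assuming the node's role arrives in the TF_CONFIG environment variable (the JSON format that tf.estimator's RunConfig reads: a "cluster" dict plus a "task" with "type" and "index"). The `role_for` helper and the "worker 0 acts as master worker" rule are hypothetical conventions for illustration.

```python
import json
import os


def read_tf_config():
    """Return (cluster dict, job name, task index) parsed from TF_CONFIG."""
    config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    task = config.get("task", {})
    return config.get("cluster", {}), task.get("type"), task.get("index", 0)


def role_for(job_name, task_index):
    """Hypothetical dispatch: decide which code block this node should run."""
    if job_name == "ps":
        return "parameter server"
    if job_name == "worker":
        # Assumed convention: the first worker doubles as the master worker (chief).
        return "master worker" if task_index == 0 else "worker"
    raise ValueError("unknown job name: %r" % job_name)


# Simulate what a cluster manager would set for the second worker node.
# Every node runs this same script; only TF_CONFIG differs between nodes.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"ps": ["ps0:2222"], "worker": ["worker0:2222", "worker1:2222"]},
    "task": {"type": "worker", "index": 1},
})
cluster, job_name, task_index = read_tf_config()
role = role_for(job_name, task_index)
```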

How to prepare your code for Distributed TensorFlow:

  1. Define tf.train.ClusterSpec, which describes all the tasks and jobs in the cluster.
  2. Create a tf.train.Server for each task in that cluster; each server corresponds to a particular task in a named job. Each task will typically run on a different machine.
  3. Assign parts of your model to the named jobs, e.g. variables to ps and computation to workers.
  4. Set up and launch tf.train.MonitoredTrainingSession.
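Steps 1 to 3 can be sketched as follows for a hypothetical cluster of 2 parameter servers and 2 workers. It uses `tensorflow.compat.v1` with `disable_v2_behavior()` so the TF 1.x-style API also runs under TF 2.x. The host names are placeholders, so the server-creating function is defined but never called; step 4 needs a running cluster and is left out.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Step 1: describe every job and task. Host names are hypothetical placeholders.
cluster = tf.train.ClusterSpec({
    "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})


def start_server(job_name, task_index):
    # Step 2: each process starts exactly one server for its own task.
    # Not called here, since it would try to bind the placeholder address.
    return tf.train.Server(cluster, job_name=job_name, task_index=task_index)


# Step 3: variables go to the ps job, computation stays on this worker.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:0", cluster=cluster)):
    w = tf.Variable(tf.zeros([10]), name="w")  # round-robin: ps task 0
    b = tf.Variable(tf.zeros([10]), name="b")  # round-robin: ps task 1
    y = w + b                                  # placed on the worker
```

Building the graph needs no running servers: device strings are only resolved when a session on a real cluster runs it.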


Cluster managers like Kubernetes or Borg usually handle creating the ClusterSpec.

Round-robin variables

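import tensorflow as tf

# manually pin variables to a specific parameter server task (example variables)
with tf.device("/job:ps/task:0"):
    weights = tf.Variable(tf.zeros([784, 10]), name="weights")

# assign variables to the parameter servers in a round-robin fashion (default)
with tf.device(tf.train.replica_device_setter(ps_tasks=3)):
    biases = tf.Variable(tf.zeros([10]), name="biases")

# load balancing: place each variable on the parameter server with the fewest bytes so far
greedy = tf.contrib.training.GreedyLoadBalancingStrategy(
    num_tasks=3, load_fn=tf.contrib.training.byte_size_load_fn)
with tf.device(tf.train.replica_device_setter(ps_tasks=3, ps_strategy=greedy)):
    hidden = tf.Variable(tf.zeros([784, 256]), name="hidden")

# partition one huge variable into 3 shards spread over the parameter servers
with tf.device(tf.train.replica_device_setter(ps_tasks=3)):
    partitioner = tf.fixed_size_partitioner(3)
    embedding = tf.get_variable("embedding", [10000000, 20], partitioner=partitioner)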
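Why load balancing matters can be shown with a small pure-Python model of the two placement strategies. This is a simplified simulation of the idea, not TensorFlow's implementation; the variable sizes are hypothetical and ties in the greedy strategy go to the lowest task index.

```python
def round_robin(sizes, num_ps):
    """Assign the i-th variable to parameter server i % num_ps, ignoring size."""
    return [i % num_ps for i in range(len(sizes))]


def greedy(sizes, num_ps):
    """Assign each variable to the parameter server holding the fewest bytes so far."""
    loads = [0] * num_ps
    placement = []
    for size in sizes:
        task = loads.index(min(loads))  # ties go to the lowest task index
        loads[task] += size
        placement.append(task)
    return placement


def ps_loads(sizes, placement, num_ps):
    """Total bytes stored on each parameter server for a given placement."""
    totals = [0] * num_ps
    for size, task in zip(sizes, placement):
        totals[task] += size
    return totals


# Hypothetical sizes: two big weight matrices, each followed by two small biases.
sizes = [100, 1, 1, 100, 1, 1]
rr_loads = ps_loads(sizes, round_robin(sizes, 3), 3)    # [200, 2, 2]
greedy_loads = ps_loads(sizes, greedy(sizes, 3), 3)     # [100, 101, 3]
```

Round-robin happens to put both big matrices on task 0, so that one server does almost all the I/O; the greedy strategy spreads them out.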


A plain tf.Session knows only about the devices on the local machine.

To use other machines, you have to create multiple tf.train.Server instances that communicate with each other.
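A minimal sketch of servers talking to each other, assuming both tasks of a hypothetical one-job cluster ("local") run in the same process on free localhost ports. A session connected to task 0 runs an operation that is pinned to task 1. It uses `tensorflow.compat.v1` with `disable_v2_behavior()` so the TF 1.x-style API runs under TF 2.x too.

```python
import socket

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def free_ports(n):
    # Reserve n distinct unused localhost ports, then release them for TF to bind.
    socks = [socket.socket() for _ in range(n)]
    for s in socks:
        s.bind(("localhost", 0))
    ports = [s.getsockname()[1] for s in socks]
    for s in socks:
        s.close()
    return ports


port0, port1 = free_ports(2)
cluster = tf.train.ClusterSpec({"local": ["localhost:%d" % port0, "localhost:%d" % port1]})
server0 = tf.train.Server(cluster, job_name="local", task_index=0)
server1 = tf.train.Server(cluster, job_name="local", task_index=1)

# Pin the computation to task 1...
with tf.device("/job:local/task:1"):
    product = tf.constant(2.0) * tf.constant(3.0)

# ...but run it through a session connected to task 0's server.
with tf.Session(server0.target) as sess:
    result = sess.run(product)
```

In a real deployment each server would live in its own process on its own machine; running both here only keeps the example self-contained.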


Always use sharded saving. Otherwise all variables are gathered to a single device for writing, which will most likely run out of memory with big models.

saver = tf.train.Saver(sharded=True)
# sess, step and is_chief come from your training loop
if is_chief and step % 1000 == 0:
    saver.save(sess, "/home/hello/...")
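A single-machine sketch of the same chief-only periodic save, assuming a hypothetical counter variable as the model state and a temporary directory instead of a real checkpoint path. With only one device there is a single shard; on a cluster, `sharded=True` makes each device that holds variables write its own shard. It uses `tensorflow.compat.v1` with `disable_v2_behavior()`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

step_var = tf.Variable(0, name="step")  # hypothetical stand-in for model state
increment = step_var.assign_add(1)
saver = tf.train.Saver(sharded=True)    # shard checkpoints per device

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
is_chief = True  # assumption: this process is the master worker
saved = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        step = sess.run(increment)
        if is_chief and step % 1000 == 0:
            saved.append(saver.save(sess, ckpt_prefix, global_step=step))

# A fresh session (e.g. after a crash) restores the latest checkpoint.
with tf.Session() as sess:
    saver.restore(sess, saved[-1])
    restored = sess.run(step_var)
```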


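# server is this task's tf.train.Server, is_chief is True on the master worker, train_op is your training op
with tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief) as sess:
    while not sess.should_stop():
        sess.run(train_op)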

MonitoredTrainingSession makes chief recovery easy: it restores variables from the latest checkpoint if one exists and otherwise initializes them. run() automatically recovers from parameter server failures and triggers the registered hooks.
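A minimal local sketch of checkpoint-based recovery, assuming `master=""` (an in-process session) instead of a cluster target and a hypothetical `train_until` helper whose training step only increments the global step. The second call rebuilds the graph, as a restarted chief would, and resumes from the checkpoint the first call left behind. It uses `tensorflow.compat.v1` with `disable_v2_behavior()`.

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def train_until(ckpt_dir, last_step):
    # Hypothetical helper: build a fresh graph, as a restarted chief process would.
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        train_op = tf.assign_add(global_step, 1)  # stand-in for a real training step
        hooks = [tf.train.StopAtStepHook(last_step=last_step)]
        # On a cluster, master would be server.target instead of "".
        with tf.train.MonitoredTrainingSession(
                master="", is_chief=True, checkpoint_dir=ckpt_dir, hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)
    # A checkpoint is written when the session closes; read the step back from it.
    return tf.train.load_variable(ckpt_dir, "global_step")


ckpt_dir = tempfile.mkdtemp()
first_run = train_until(ckpt_dir, 10)   # starts from scratch
second_run = train_until(ckpt_dir, 15)  # restores step 10, continues to 15
```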

Fault tolerance

  • Worker fails => workers are stateless, so the cluster manager can simply bring a new one online
  • Parameter server fails => parameter servers are stateful, so the master worker is responsible for noticing that a ps went down, halting the workers, booting up a replacement and restoring the parameters from the last checkpoint
  • Master worker fails => interrupt operations and wait until a new master worker is found

High-level API

You can use Estimators to run distributed TensorFlow without touching the low-level code.

# how to distribute your model
distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=2)

run_config = tf.estimator.RunConfig(train_distribute=distribution)

# model_fn is your own model function (not shown here)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

The tf.data input pipeline has three sets of tools:

  1. Extract: read data from various data sources.
  2. Transform: convert data from one format to another.
  3. Load: move data onto the devices, e.g. GPUs or TPUs.


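# FILE_PATTERN is your glob for the input TFRecord files
files = tf.data.Dataset.list_files(FILE_PATTERN)
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=32)

# feature_spec maps feature names to parsing configs, e.g. tf.FixedLenFeature
dataset = dataset.shuffle(10000)
dataset = dataset.repeat(NUM_EPOCHS)
dataset = dataset.map(lambda x: tf.parse_single_example(x, feature_spec))
dataset = dataset.batch(NUM_BATCH)  # NUM_BATCH is the batch size

iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()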
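A self-contained version of the same extract, transform and load pipeline, assuming a tiny hypothetical dataset: it writes 10 tf.train.Example records with a single int64 feature "x" to a temporary TFRecord file, then reads them back for 2 epochs in batches of 4. It uses `tensorflow.compat.v1` with `disable_v2_behavior()`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Write hypothetical data so the pipeline has something to extract.
path = os.path.join(tempfile.mkdtemp(), "data.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(10):
        example = tf.train.Example(features=tf.train.Features(feature={
            "x": tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        }))
        writer.write(example.SerializeToString())

feature_spec = {"x": tf.io.FixedLenFeature([], tf.int64)}

dataset = tf.data.TFRecordDataset([path], num_parallel_reads=2)  # extract
dataset = dataset.shuffle(10)
dataset = dataset.repeat(2)
dataset = dataset.map(lambda r: tf.io.parse_single_example(r, feature_spec))  # transform
dataset = dataset.batch(4)

iterator = tf.data.make_one_shot_iterator(dataset)
batch = iterator.get_next()  # load: fetch one batch per session run

values = []
with tf.Session() as sess:
    while True:
        try:
            values.extend(sess.run(batch)["x"].tolist())
        except tf.errors.OutOfRangeError:  # raised once both epochs are consumed
            break
```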


Each call to tf.Session() creates a separate execution engine: the component that stores variable values and runs operations. By default, execution engines don't share state.

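import tensorflow as tf

variable = tf.Variable(initial_value=0.0)

sess1 = tf.Session()
sess1.run(tf.global_variables_initializer())

sess2 = tf.Session()
sess2.run(tf.global_variables_initializer())

print("Initial value of var in session 1:", sess1.run(variable))  # 0.0
print("Initial value of var in session 2:", sess2.run(variable))  # 0.0

sess1.run(variable.assign_add(1.0))  # increment only in session 1

print("Value of var in session 1:", sess1.run(variable))  # 1.0
print("Value of var in session 2:", sess2.run(variable))  # 0.0, sessions don't share state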

In Distributed TensorFlow:

  • each process runs a special execution engine: a TensorFlow server
  • TensorFlow servers form a cluster
  • each server in the cluster is also known as a task
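In contrast to the two local sessions above, sessions connected to the same TensorFlow server share its state. A minimal sketch, assuming a single in-process server from `tf.train.Server.create_local_server()`; it uses `tensorflow.compat.v1` with `disable_v2_behavior()`.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# One in-process TensorFlow server; both sessions connect to it.
server = tf.train.Server.create_local_server()

variable = tf.Variable(initial_value=0.0, name="shared_var")

sess1 = tf.Session(server.target)
sess2 = tf.Session(server.target)

sess1.run(tf.global_variables_initializer())
sess1.run(variable.assign_add(1.0))  # increment through session 1 only

value_in_sess1 = sess1.run(variable)
value_in_sess2 = sess2.run(variable)  # the server holds the state, so session 2 sees it
```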