# Distribution Strategies

Multi-device and multi-worker training strategies for scaling machine learning workloads across GPUs and TPUs. These strategies enable efficient distributed training and deployment.

## Capabilities

### Strategy Classes

Core distribution strategy classes for different distributed training scenarios.

```python { .api }
class Strategy:
    """
    Base class for distribution strategies.

    Methods:
    - scope(): Returns a context manager selecting this Strategy as current
    - run(fn, args=(), kwargs=None, options=None): Invokes fn on each replica, with the given arguments
    - reduce(reduce_op, value, axis): Reduce value across replicas and return result on current device
    - gather(value, axis): Gather value across replicas along axis to current device
    """

class MirroredStrategy(Strategy):
    """
    Synchronous training across multiple replicas on one machine.

    This strategy is typically used for training on one machine with multiple GPUs.
    Variables and updates are mirrored across all replicas.

    Parameters:
    - devices: Optional list of device strings or device objects. If not specified, all visible GPUs are used
    - cross_device_ops: Optional CrossDeviceOps instance (e.g. NcclAllReduce, ReductionToOneDevice) specifying how values are combined across devices
    """

class MultiWorkerMirroredStrategy(Strategy):
    """
    Synchronous training across multiple workers, each with potentially multiple replicas.

    This strategy implements synchronous distributed training across multiple workers,
    each of which may have multiple GPUs. Similar to MirroredStrategy, it replicates
    all variables and computations to each local replica.

    Parameters:
    - cluster_resolver: Optional cluster resolver
    - communication_options: Optional, communication options for CollectiveOps
    """

class TPUStrategy(Strategy):
    """
    Synchronous training on TPUs and TPU Pods.

    This strategy is for running on TPUs, including TPU pods which can scale
    to hundreds or thousands of cores.

    Parameters:
    - tpu_cluster_resolver: A TPUClusterResolver, which provides information about the TPU cluster
    - experimental_device_assignment: Optional, a DeviceAssignment to run replicas on
    - experimental_spmd_xla_partitioning: Optional boolean for using SPMD-style sharding
    """

class OneDeviceStrategy(Strategy):
    """
    A distribution strategy for running on a single device.

    Using this strategy will place any variables created in its scope on the specified device.
    Input distributed through this strategy will be prefetched to the specified device.

    Parameters:
    - device: Device string identifier for the device on which the variables should be placed
    """

class CentralStorageStrategy(Strategy):
    """
    A one-machine strategy that puts all variables on a single device.

    Variables are assigned to the local CPU and operations are replicated across
    all local GPUs. If there is only one GPU, operations will run on that GPU.

    Parameters:
    - compute_devices: Optional list of device strings for placing operations
    - parameter_device: Optional device string for placing variables
    """

class ParameterServerStrategy(Strategy):
    """
    An asynchronous multi-worker parameter server strategy.

    Parameter server training is a common data-parallel method to scale up a
    machine learning model on multiple machines.

    Parameters:
    - cluster_resolver: A ClusterResolver object specifying cluster configuration
    - variable_partitioner: Optional callable for partitioning variables across parameter servers
    """
```
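As a quick orientation, the sketch below shows the common pattern of constructing a strategy and creating variables under its `scope()`. It assumes a single machine: `MirroredStrategy` falls back to one replica when no extra GPUs are visible, and the exact distributed-variable wrapper type printed at the end is an implementation detail that varies across TensorFlow versions.

```python
import tensorflow as tf

# With no extra GPUs visible, MirroredStrategy uses a single replica,
# so this sketch also runs on CPU-only machines.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Variables created inside scope() are created once per replica
    # and kept in sync by the strategy.
    counter = tf.Variable(0.0, name="counter")

print(type(counter).__name__)  # distributed variable wrapper (version-dependent)
```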
### Strategy Context and Execution

Methods for running code within distribution strategy contexts.

```python { .api }
def scope(self):
    """
    Context manager to make the strategy current and distribute variables created in scope.

    Returns:
    A context manager
    """

def run(self, fn, args=(), kwargs=None, options=None):
    """
    Invokes fn on each replica, with the given arguments.

    Parameters:
    - fn: The function to run on each replica
    - args: Optional positional arguments to fn
    - kwargs: Optional keyword arguments to fn
    - options: Optional RunOptions specifying the options to run fn

    Returns:
    Merged return value of fn across replicas
    """

def reduce(self, reduce_op, value, axis):
    """
    Reduce value across replicas and return the result on the current device.

    Parameters:
    - reduce_op: A ReduceOp value specifying how values should be combined
    - value: A "per replica" value, e.g. returned by run
    - axis: The dimension to reduce along within each replica's tensor, or None to reduce only across replicas

    Returns:
    A Tensor
    """

def gather(self, value, axis):
    """
    Gather value across replicas along axis to the current device.

    Parameters:
    - value: A "per replica" value, e.g. returned by Strategy.run
    - axis: 0-D int32 Tensor. Dimension along which to gather

    Returns:
    A Tensor that is the concatenation of value across replicas along the axis dimension
    """
```
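A minimal sketch of how these methods compose: `run` executes a function once per replica, `reduce` combines the per-replica results, and `gather` concatenates them along an axis (`Strategy.gather` requires a reasonably recent TensorFlow 2.x release). The input values here are illustrative.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
    # Each replica scales the (broadcast) input by its replica id + 1.
    ctx = tf.distribute.get_replica_context()
    return x * tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)

per_replica = strategy.run(replica_fn, args=(tf.constant([1.0, 2.0]),))

# Element-wise sum over replicas; axis=None reduces across replicas only.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

# Concatenate the per-replica tensors along dimension 0.
stacked = strategy.gather(per_replica, axis=0)

print(total.numpy(), stacked.numpy())
```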
### Distribution Utilities

Utility functions for working with distributed training.

```python { .api }
def get_strategy():
    """
    Returns the current tf.distribute.Strategy object.

    Returns:
    A Strategy object. Inside a with strategy.scope() block, returns strategy,
    otherwise returns the default (single-replica) strategy
    """

def has_strategy():
    """
    Returns whether there is a current non-default tf.distribute.Strategy.

    Returns:
    True if inside a with strategy.scope() block for a non-default strategy
    """

def in_cross_replica_context():
    """
    Returns True if in a cross-replica context.

    Returns:
    True if in a cross-replica context, False if in a replica context
    """

def get_replica_context():
    """
    Returns the current tf.distribute.ReplicaContext or None.

    Returns:
    The current ReplicaContext object when in a replica context, else None
    """

def experimental_set_strategy(strategy):
    """
    Set a tf.distribute.Strategy as current without entering a with strategy.scope() block.

    Parameters:
    - strategy: A tf.distribute.Strategy object or None
    """
```
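The following sketch illustrates what these utilities return inside and outside a strategy scope and inside `Strategy.run`; the printed values assume a default single-machine setup.

```python
import tensorflow as tf

# Outside any scope: only the default strategy is active.
print(tf.distribute.has_strategy())              # False
print(tf.distribute.in_cross_replica_context())  # True

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    print(tf.distribute.has_strategy())                  # True
    print(type(tf.distribute.get_strategy()).__name__)   # MirroredStrategy

def replica_fn():
    # Inside run() we are in a replica context, so a ReplicaContext is available.
    ctx = tf.distribute.get_replica_context()
    return ctx.replica_id_in_sync_group

print(strategy.run(replica_fn))
```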
### Reduce Operations

Operations for combining values across replicas.

```python { .api }
class ReduceOp:
    """Indicates how a set of values should be reduced."""

    SUM = "SUM"    # Sum across replicas
    MEAN = "MEAN"  # Mean across replicas

class CrossDeviceOps:
    """Base class for cross-device reduction and broadcasting algorithms."""

    def reduce(self, reduce_op, per_replica_value, destinations):
        """
        Reduce per_replica_value to destinations.

        Parameters:
        - reduce_op: Indicates how per_replica_value will be reduced
        - per_replica_value: A PerReplica object or a tensor with device placement
        - destinations: The return value will be copied to these destinations

        Returns:
        A tensor or PerReplica object
        """

    def broadcast(self, tensor, destinations):
        """
        Broadcast tensor to destinations.

        Parameters:
        - tensor: The tensor to broadcast
        - destinations: The broadcast destinations

        Returns:
        A tensor or PerReplica object
        """
```
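Concrete `CrossDeviceOps` implementations, such as `tf.distribute.ReductionToOneDevice`, `tf.distribute.NcclAllReduce`, and `tf.distribute.HierarchicalCopyAllReduce`, are normally not called directly; they are passed to `MirroredStrategy`, which uses them when a `ReduceOp` reduction is requested. A minimal sketch, using `ReductionToOneDevice` so it also works without NVIDIA GPUs:

```python
import tensorflow as tf

# ReductionToOneDevice copies values to one device and reduces there;
# NcclAllReduce is usually faster but requires NVIDIA GPUs.
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.ReductionToOneDevice()
)

per_replica = strategy.run(lambda: tf.constant(2.0))

# ReduceOp selects how the per-replica values are combined.
print(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None))
print(strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None))
```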
## Usage Examples

```python
import tensorflow as tf
import numpy as np

# Single GPU strategy
strategy = tf.distribute.OneDeviceStrategy("/gpu:0")

# Multi-GPU strategy (automatic GPU detection)
strategy = tf.distribute.MirroredStrategy()

# Explicit device specification
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

# Multi-worker strategy (requires cluster setup)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Create and compile model within strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

# Prepare distributed dataset
def make_dataset():
    x = np.random.random((1000, 10))
    y = np.random.randint(2, size=(1000, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.batch(32)

# Distribute dataset across replicas
dataset = make_dataset()
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Custom training loop with strategy
with strategy.scope():
    # Define loss and metrics (per-example loss; averaging is done manually)
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.NONE
    )

    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=32)

    train_accuracy = tf.keras.metrics.BinaryAccuracy()

    optimizer = tf.keras.optimizers.Adam()

# Training step function
def train_step(inputs):
    features, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss

# Distributed training step
@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Training loop
for epoch in range(5):
    total_loss = 0.0
    num_batches = 0

    for x in dist_dataset:
        loss = distributed_train_step(x)
        total_loss += loss.numpy()
        num_batches += 1

    train_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, "
          f"Accuracy: {train_accuracy.result().numpy():.4f}")

    train_accuracy.reset_states()

# Using built-in Keras fit with strategy
with strategy.scope():
    model_fit = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model_fit.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

# Keras fit automatically handles distribution
model_fit.fit(dataset, epochs=5)

# Multi-worker setup example (requires environment configuration)
# Set TF_CONFIG environment variable before running:
# os.environ['TF_CONFIG'] = json.dumps({
#     'cluster': {
#         'worker': ["host1:port", "host2:port", "host3:port"],
#         'ps': ["host4:port", "host5:port"]
#     },
#     'task': {'type': 'worker', 'index': 1}
# })
# ('ps' entries are only needed when using ParameterServerStrategy)

# Strategy utilities
current_strategy = tf.distribute.get_strategy()
print(f"Current strategy: {type(current_strategy).__name__}")
print(f"Number of replicas: {current_strategy.num_replicas_in_sync}")

# Check execution context
if tf.distribute.in_cross_replica_context():
    print("In cross-replica context")
else:
    print("In replica context")

# Custom reduction example
with strategy.scope():
    @tf.function
    def replica_fn():
        return tf.constant([1.0, 2.0, 3.0])

    # Run function on all replicas
    per_replica_result = strategy.run(replica_fn)

    # Reduce across replicas (axis=None: reduce across replicas only)
    reduced_sum = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_result, axis=None)
    reduced_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_result, axis=None)

    print(f"Sum: {reduced_sum}")
    print(f"Mean: {reduced_mean}")
```
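Building on the TF_CONFIG comment above, the sketch below shows the per-worker script shape for `MultiWorkerMirroredStrategy` with Keras `fit`. The cluster definition uses a single local worker so the snippet can run on one machine; in a real deployment every worker runs the same script with all workers' `host:port` pairs listed and its own `task` index, and those host names are placeholders.

```python
import json
import os

import numpy as np
import tensorflow as tf

# Single local worker so this runs standalone; a real cluster would list
# every worker's host:port here and set a different 'index' per worker.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["localhost:12345"]},
    "task": {"type": "worker", "index": 0},
})

# The strategy reads TF_CONFIG when it is constructed.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")

# batch_size is the global batch size; Keras shards the data across workers.
x = np.random.random((1000, 10)).astype("float32")
y = np.random.randint(2, size=(1000, 1))
model.fit(x, y, epochs=2, batch_size=64)
```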