
# Distribution Strategies


Multi-device and multi-worker training strategies for scaling machine learning workloads across GPUs and TPUs. These strategies enable efficient distributed training and deployment.

## Capabilities

### Strategy Classes

Core distribution strategy classes for different distributed training scenarios.


```python { .api }

class Strategy:
    """
    Base class for distribution strategies.

    Methods:
    - scope(): Returns a context manager selecting this Strategy as current
    - run(fn, args=(), kwargs=None, options=None): Invokes fn on each replica, with the given arguments
    - reduce(reduce_op, value, axis): Reduce value across replicas and return result on current device
    - gather(value, axis): Gather value across replicas along axis to current device
    """

class MirroredStrategy(Strategy):
    """
    Synchronous training across multiple replicas on one machine.

    This strategy is typically used for training on one machine with multiple GPUs.
    Variables and updates will be mirrored across all replicas.

    Parameters:
    - devices: Optional list of device strings or device objects. If not specified, all visible GPUs are used

    - cross_device_ops: Optional, a CrossDeviceOps instance specifying how cross-device reductions and broadcasts are performed

"""

33

34

class MultiWorkerMirroredStrategy(Strategy):

35

"""

36

Synchronous training across multiple workers, each with potentially multiple replicas.

37

38

This strategy implements synchronous distributed training across multiple workers,

39

each of which may have multiple GPUs. Similar to MirroredStrategy, it replicates

40

all variables and computations to each local replica.

41

42

Parameters:

43

- cluster_resolver: Optional cluster resolver

44

- communication_options: Optional, communication options for CollectiveOps

45

"""

46

47

class TPUStrategy(Strategy):

48

"""

49

Synchronous training on TPUs and TPU Pods.

50

51

This strategy is for running on TPUs, including TPU pods which can scale

52

to hundreds or thousands of cores.

53

54

Parameters:

55

- tpu_cluster_resolver: A TPUClusterResolver, which provides information about the TPU cluster

56

- experimental_device_assignment: Optional, a DeviceAssignment to run replicas on

57

- experimental_spmd_xla_partitioning: Optional boolean for using SPMD-style sharding

58

"""

59

60

class OneDeviceStrategy(Strategy):

61

"""

62

A distribution strategy for running on a single device.

63

64

Using this strategy will place any variables created in its scope on the specified device.

65

Input distributed through this strategy will be prefetched to the specified device.

66

67

Parameters:

68

- device: Device string identifier for the device on which the variables should be placed

69

"""

70

71

class CentralStorageStrategy(Strategy):

72

"""

73

A one-machine strategy that puts all variables on a single device.

74

75

Variables are assigned to local CPU and operations are replicated across

76

all local GPUs. If there is only one GPU, operations will run on that GPU.

77

78

Parameters:

79

- compute_devices: Optional list of device strings for placing operations

80

- parameter_device: Optional device string for placing variables

81

"""

82

83

class ParameterServerStrategy(Strategy):

84

"""

85

An asynchronous multi-worker parameter server strategy.

86

87

Parameter server training is a common data-parallel method to scale up a

88

machine learning model on multiple machines.

89

90

Parameters:

91

- cluster_resolver: A ClusterResolver object specifying cluster configuration

92

- variable_partitioner: Optional callable for partitioning variables across parameter servers

93

"""

94

```
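
For example, a minimal sketch (not part of the API above, assuming a single machine where `MirroredStrategy` may find zero or more GPUs) of creating a strategy and placing variables under its scope:

```python
import tensorflow as tf

# MirroredStrategy detects visible GPUs and falls back to a single replica
# on a CPU-only machine, so this sketch runs almost anywhere.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Variables created here are mirrored across all replicas.
    weights = tf.Variable(tf.zeros([10, 1]))
```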


### Strategy Context and Execution

Methods for running code within distribution strategy contexts.


```python { .api }

def scope(self):
    """
    Context manager to make the strategy current and distribute variables created in scope.

    Returns:
    A context manager
    """

def run(self, fn, args=(), kwargs=None, options=None):
    """
    Invokes fn on each replica, with the given arguments.

    Parameters:
    - fn: The function to run on each replica
    - args: Optional positional arguments to fn
    - kwargs: Optional keyword arguments to fn
    - options: Optional RunOptions specifying the options to run fn

    Returns:
    Merged return value of fn across replicas
    """

def reduce(self, reduce_op, value, axis=None):
    """
    Reduce value across replicas and return result on current device.

    Parameters:
    - reduce_op: A ReduceOp value specifying how values should be combined
    - value: A "per replica" value, e.g. returned by run
    - axis: Specifies the dimension to reduce along within each replica's tensor

    Returns:
    A Tensor
    """

def gather(self, value, axis):
    """
    Gather value across replicas along axis to current device.

    Parameters:
    - value: A "per replica" value, e.g. returned by Strategy.run
    - axis: 0-D int32 Tensor. Dimension along which to gather

    Returns:
    A Tensor that's the concatenation of value across replicas along axis dimension
    """

```
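
A minimal sketch of how these methods fit together (assuming a default `MirroredStrategy`, which degrades gracefully to one replica on a CPU-only machine):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn():
    # Executed once per replica; return a one-element vector so gather()
    # has a dimension to concatenate along.
    ctx = tf.distribute.get_replica_context()
    return tf.fill([1], tf.cast(ctx.replica_id_in_sync_group, tf.float32))

per_replica = strategy.run(step_fn)

# Combine the per-replica values on the current (host) device.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
all_ids = strategy.gather(per_replica, axis=0)
print(total, all_ids)
```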


### Distribution Utilities

Utility functions for working with distributed training.


```python { .api }

def get_strategy():
    """
    Returns the current tf.distribute.Strategy object.

    Returns:
    A Strategy object. Inside a with strategy.scope() block, returns strategy,
    otherwise returns the default (single-replica) strategy
    """

def has_strategy():
    """
    Return if there is a current non-default tf.distribute.Strategy.

    Returns:
    True if inside a with strategy.scope() block for a non-default strategy
    """

def in_cross_replica_context():
    """
    Returns True if in a cross-replica context.

    Returns:
    True if in a cross-replica context, False if in a replica context
    """

def get_replica_context():
    """
    Returns the current tf.distribute.ReplicaContext or None.

    Returns:
    The current ReplicaContext object when in a replica context, else None
    """

def experimental_set_strategy(strategy):
    """
    Set a tf.distribute.Strategy as current without with strategy.scope().

    Parameters:
    - strategy: A tf.distribute.Strategy object or None
    """

```
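
For instance (a small sketch, assuming `MirroredStrategy`), these utilities distinguish the default strategy, an active scope, and replica vs. cross-replica context:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Outside any scope, the default (single-replica) strategy is current.
print(tf.distribute.has_strategy())                  # False
print(type(tf.distribute.get_strategy()).__name__)

with strategy.scope():
    # Inside the scope, the custom strategy is current, and code here runs
    # in cross-replica context (we are not inside strategy.run()).
    print(tf.distribute.has_strategy())               # True
    print(tf.distribute.get_strategy() is strategy)   # True
    print(tf.distribute.in_cross_replica_context())   # True

    def replica_fn():
        # Inside strategy.run() we are in a replica context instead.
        return tf.distribute.get_replica_context().replica_id_in_sync_group

    print(strategy.run(replica_fn))
```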


### Reduce Operations

Operations for combining values across replicas.


```python { .api }

class ReduceOp:
    """Indicates how a set of values should be reduced."""

    SUM = "SUM"    # Sum across replicas
    MEAN = "MEAN"  # Mean across replicas



class CrossDeviceOps:
    """Base class for cross-device reduction and broadcasting algorithms."""

    def reduce(self, reduce_op, per_replica_value, destinations):
        """
        Reduce per_replica_value to destinations.

        Parameters:
        - reduce_op: Indicates how per_replica_value will be reduced
        - per_replica_value: A PerReplica object or a tensor with device placement
        - destinations: The return value will be copied to these destinations

        Returns:
        A tensor or PerReplica object
        """

    def broadcast(self, tensor, destinations):
        """
        Broadcast tensor to destinations.

        Parameters:
        - tensor: The tensor to broadcast
        - destinations: The broadcast destinations

        Returns:
        A tensor or PerReplica object
        """

```
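
A short sketch of how `ReduceOp` and a `CrossDeviceOps` implementation are typically used (`ReductionToOneDevice` is chosen here purely for illustration; `MirroredStrategy` picks a sensible default when `cross_device_ops` is omitted):

```python
import tensorflow as tf

# cross_device_ops controls how MirroredStrategy combines values across devices.
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.ReductionToOneDevice()
)

per_replica = strategy.run(lambda: tf.constant([1.0, 2.0]))

# ReduceOp selects the combination: SUM adds across replicas; MEAN averages.
summed = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
# With axis=0 the mean is also taken over axis 0 within each replica's tensor.
mean_all = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=0)
print(summed, mean_all)
```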


## Usage Examples


```python

241

import tensorflow as tf

242

import numpy as np

243

244

# Single GPU strategy

245

strategy = tf.distribute.OneDeviceStrategy("/gpu:0")

246

247

# Multi-GPU strategy (automatic GPU detection)

248

strategy = tf.distribute.MirroredStrategy()

249

250

# Explicit device specification

251

strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

252

253

# Multi-worker strategy (requires cluster setup)

254

strategy = tf.distribute.MultiWorkerMirroredStrategy()

255

256


# Create and compile model within strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

# Prepare distributed dataset
def make_dataset():
    x = np.random.random((1000, 10))
    y = np.random.randint(2, size=(1000, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.batch(32)

# Distribute dataset across replicas
dataset = make_dataset()
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Custom training loop with strategy
with strategy.scope():
    # Define loss and metrics
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.NONE
    )

    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=32)

    train_accuracy = tf.keras.metrics.BinaryAccuracy()

    optimizer = tf.keras.optimizers.Adam()

# Training step function
def train_step(inputs):
    features, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss

# Distributed training step
@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Training loop
for epoch in range(5):
    total_loss = 0.0
    num_batches = 0

    for x in dist_dataset:
        loss = distributed_train_step(x)
        total_loss += loss.numpy()
        num_batches += 1

    train_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, "
          f"Accuracy: {train_accuracy.result():.4f}")

    train_accuracy.reset_states()

# Using built-in Keras fit with strategy
with strategy.scope():
    model_fit = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model_fit.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

# Keras fit automatically handles distribution
model_fit.fit(dataset, epochs=5)

# Multi-worker setup example (requires environment configuration)
# Set TF_CONFIG environment variable before running:
# os.environ['TF_CONFIG'] = json.dumps({
#     'cluster': {
#         'worker': ["host1:port", "host2:port", "host3:port"],
#         'ps': ["host4:port", "host5:port"]
#     },
#     'task': {'type': 'worker', 'index': 1}
# })

# Strategy utilities
current_strategy = tf.distribute.get_strategy()
print(f"Current strategy: {type(current_strategy).__name__}")
print(f"Number of replicas: {current_strategy.num_replicas_in_sync}")

# Check execution context
if tf.distribute.in_cross_replica_context():
    print("In cross-replica context")
else:
    print("In replica context")

# Custom reduction example
with strategy.scope():
    @tf.function
    def replica_fn():
        return tf.constant([1.0, 2.0, 3.0])

    # Run function on all replicas
    per_replica_result = strategy.run(replica_fn)

    # Reduce across replicas

    reduced_sum = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_result, axis=None)
    reduced_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_result, axis=None)


    print(f"Sum: {reduced_sum}")
    print(f"Mean: {reduced_mean}")

```