or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

config-persistence.md core-framework.md data-examples.md distributed-runtime.md index.md utilities.md

distributed-runtime.md docs/

0

# Distributed Runtime

1

2

Service interfaces and message types for distributed TensorFlow execution across multiple devices and machines. These types enable TensorFlow to run computations on clusters and coordinate between master and worker nodes.

3

4

## Capabilities

5

6

### Master Service Messages

7

8

The master service coordinates distributed TensorFlow sessions and manages the overall execution flow across worker nodes.

9

10

#### CreateSessionRequest/Response

11

12

Creates a new TensorFlow session on the cluster.

13

14

```java { .api }

15

/**

16

* Request to create a new session

17

*/

18

class CreateSessionRequest {

19

/** Get the computation graph */

20

GraphDef getGraphDef();

21

22

/** Get session configuration */

23

ConfigProto getConfig();

24

25

/** Get target specification (e.g., "grpc://localhost:2222") */

26

String getTarget();

27

28

/** Create a new builder */

29

static Builder newBuilder();

30

31

static class Builder {

32

Builder setGraphDef(GraphDef graphDef);

33

Builder setConfig(ConfigProto config);

34

Builder setTarget(String target);

35

CreateSessionRequest build();

36

}

37

}

38

39

/**

40

* Response with created session handle

41

*/

42

class CreateSessionResponse {

43

/** Get unique session handle */

44

String getSessionHandle();

45

46

/** Get cluster information */

47

ClusterDef getClusterDef();

48

49

/** Get graph version */

50

int getGraphVersion();

51

}

52

```

53

54

#### RunStepRequest/Response

55

56

Executes a computation step in a distributed session.

57

58

```java { .api }

59

/**

60

* Request to run a computation step

61

*/

62

class RunStepRequest {

63

/** Get session handle */

64

String getSessionHandle();

65

66

/** Get input feed mappings (tensor name -> tensor value) */

67

Map<String, TensorProto> getFeedMap();

68

69

/** Get output fetch names */

70

List<String> getFetchList();

71

72

/** Get target operation names to run */

73

List<String> getTargetList();

74

75

/** Get run options */

76

RunOptions getOptions();

77

78

/** Get partial run handle for partial execution */

79

String getPartialRunHandle();

80

81

/** Check whether errors should be stored in the response body */

82

boolean getStoreErrorsInResponseBody();

83

84

/** Create a new builder */

85

static Builder newBuilder();

86

87

static class Builder {

88

Builder setSessionHandle(String handle);

89

Builder putFeed(String key, TensorProto value);

90

Builder addFetch(String fetch);

91

Builder addTarget(String target);

92

Builder setOptions(RunOptions options);

93

Builder setPartialRunHandle(String handle);

94

RunStepRequest build();

95

}

96

}

97

98

/**

99

* Response with computation results

100

*/

101

class RunStepResponse {

102

/** Get output tensor values */

103

List<TensorProto> getTensorList();

104

105

/** Get execution metadata */

106

RunMetadata getMetadata();

107

108

/** Get step execution statistics */

109

StepStats getStepStats();

110

111

/** Get cost graph information */

112

CostGraphDef getCostGraph();

113

114

/** Get status code */

115

int getStatusCode();

116

117

/** Get error message if failed */

118

String getStatusErrorMessage();

119

}

120

```

121

122

**Usage Examples:**

123

124

```java

125

import org.tensorflow.distruntime.*;

126

import org.tensorflow.framework.*;

127

128

// Create a session on distributed cluster

129

CreateSessionRequest sessionRequest = CreateSessionRequest.newBuilder()

130

.setGraphDef(myGraphDef)

131

.setConfig(ConfigProto.newBuilder()

132

.setAllowSoftPlacement(true)

133

.putDeviceCount("GPU", 2)

134

.putDeviceCount("CPU", 4)

135

.build())

136

.setTarget("grpc://chief:2222")

137

.build();

138

139

// Run a training step

140

RunStepRequest stepRequest = RunStepRequest.newBuilder()

141

.setSessionHandle(sessionHandle)

142

.putFeed("input:0", inputTensor)

143

.putFeed("labels:0", labelTensor)

144

.addFetch("loss:0")

145

.addFetch("accuracy:0")

146

.addTarget("train_op")

147

.setOptions(RunOptions.newBuilder()

148

.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)

149

.setTimeoutInMs(30000)

150

.build())

151

.build();

152

```

153

154

#### ExtendSessionRequest/Response

155

156

Extends an existing session with additional graph nodes.

157

158

```java { .api }

159

/**

160

* Request to extend session with new graph nodes

161

*/

162

class ExtendSessionRequest {

163

/** Get session handle */

164

String getSessionHandle();

165

166

/** Get additional graph definition */

167

GraphDef getGraphDef();

168

169

/** Get current graph version */

170

int getCurrentGraphVersion();

171

172

/** Create a new builder */

173

static Builder newBuilder();

174

}

175

176

/**

177

* Response with updated graph version

178

*/

179

class ExtendSessionResponse {

180

/** Get new graph version */

181

int getNewGraphVersion();

182

}

183

```

184

185

#### ListDevicesRequest/Response

186

187

Lists available devices in the cluster.

188

189

```java { .api }

190

/**

191

* Request to list available devices

192

*/

193

class ListDevicesRequest {

194

/** Get session handle (optional) */

195

String getSessionHandle();

196

}

197

198

/**

199

* Response with device information

200

*/

201

class ListDevicesResponse {

202

/** Get list of local devices */

203

List<DeviceAttributes> getLocalDeviceList();

204

205

/** Get list of remote devices */

206

List<DeviceAttributes> getRemoteDeviceList();

207

}

208

```

209

210

### Worker Service Messages

211

212

Worker services execute graph partitions on individual machines in a distributed setup.

213

214

#### RegisterGraphRequest/Response

215

216

Registers a graph partition on a worker node.

217

218

```java { .api }

219

/**

220

* Request to register graph partition on worker

221

*/

222

class RegisterGraphRequest {

223

/** Get session handle */

224

String getSessionHandle();

225

226

/** Check if this request only creates the worker session */

227

boolean getCreateWorkerSessionOnly();

228

229

/** Get graph definition */

230

GraphDef getGraphDef();

231

232

/** Check if the graph has control dependencies */

233

boolean getHasControlDependencies();

234

235

/** Get graph options */

236

GraphOptions getGraphOptions();

237

238

/** Get debug options */

239

DebugOptions getDebugOptions();

240

241

/** Get collective graph key */

242

long getCollectiveGraphKey();

243

244

/** Create a new builder */

245

static Builder newBuilder();

246

}

247

248

/**

249

* Response after registering graph

250

*/

251

class RegisterGraphResponse {

252

/** Get graph handle for future operations */

253

String getGraphHandle();

254

}

255

```

256

257

#### RunGraphRequest/Response

258

259

Executes a registered graph partition.

260

261

```java { .api }

262

/**

263

* Request to run registered graph partition

264

*/

265

class RunGraphRequest {

266

/** Get session handle */

267

String getSessionHandle();

268

269

/** Get graph handle */

270

String getGraphHandle();

271

272

/** Get step ID for coordination */

273

long getStepId();

274

275

/** Get execution count */

276

long getExecCount();

277

278

/** Get input tensors */

279

List<NamedTensorProto> getSendList();

280

281

/** Get output tensor names */

282

List<String> getRecvKeyList();

283

284

/** Check if this is a partial run */

285

boolean getIsPartial();

286

287

/** Check if this is the last partial run */

288

boolean getIsLastPartialRun();

289

290

/** Create a new builder */

291

static Builder newBuilder();

292

}

293

294

/**

295

* Response with computation results

296

*/

297

class RunGraphResponse {

298

/** Get output tensors */

299

List<NamedTensorProto> getRecvList();

300

301

/** Get step execution statistics */

302

StepStats getStepStats();

303

304

/** Get cost graph */

305

CostGraphDef getCostGraph();

306

307

/** Get partition graphs executed */

308

List<GraphDef> getPartitionGraphList();

309

}

310

```

311

312

**Usage Examples:**

313

314

```java

315

import org.tensorflow.distruntime.*;

316

317

// Register a graph partition on worker

318

RegisterGraphRequest registerRequest = RegisterGraphRequest.newBuilder()

319

.setSessionHandle(sessionHandle)

320

.setGraphDef(partitionedGraph)

321

.setHasControlDependencies(true)

322

.setGraphOptions(GraphOptions.newBuilder()

323

.setEnableRecvScheduling(true)

324

.build())

325

.build();

326

327

// Execute the registered graph

328

RunGraphRequest runRequest = RunGraphRequest.newBuilder()

329

.setSessionHandle(sessionHandle)

330

.setGraphHandle(graphHandle)

331

.setStepId(currentStepId)

332

.addSend(NamedTensorProto.newBuilder()

333

.setName("input_partition:0")

334

.setTensor(inputTensor)

335

.build())

336

.addRecvKey("output_partition:0")

337

.setIsPartial(false)

338

.build();

339

```

340

341

### Eager Service Messages

342

343

Services for TensorFlow's eager execution mode, allowing operations to be executed immediately.

344

345

#### CreateContextRequest/Response

346

347

Creates an eager execution context.

348

349

```java { .api }

350

/**

351

* Request to create eager execution context

352

*/

353

class CreateContextRequest {

354

/** Get server definition */

355

ServerDef getServerDef();

356

357

/** Check if async execution is enabled */

358

boolean getAsync();

359

360

/** Get keep alive interval in seconds */

361

int getKeepAliveSecs();

362

363

/** Get version compatibility requirements */

364

VersionDef getVersionDef();

365

366

/** Get cluster device filters */

367

ClusterDeviceFilters getClusterDeviceFilters();

368

369

/** Create a new builder */

370

static Builder newBuilder();

371

}

372

373

/**

374

* Response with context information

375

*/

376

class CreateContextResponse {

377

/** Get context ID */

378

long getContextId();

379

380

/** Get context view ID */

381

long getContextViewId();

382

383

/** Get device attributes */

384

List<DeviceAttributes> getDeviceAttributesList();

385

}

386

```

387

388

#### EnqueueRequest/Response

389

390

Enqueues operations for eager execution.

391

392

```java { .api }

393

/**

394

* Request to enqueue eager operations

395

*/

396

class EnqueueRequest {

397

/** Get context ID */

398

long getContextId();

399

400

/** Get list of operations to execute */

401

List<Operation> getQueueList();

402

403

/** Operation definition for eager execution */

404

static class Operation {

405

/** Get operation ID */

406

long getId();

407

408

/** Get operation name */

409

String getName();

410

411

/** Get operation attributes */

412

Map<String, AttrValue> getAttrsMap();

413

414

/** Get input handles */

415

List<RemoteTensorHandle> getInputsList();

416

417

/** Get control input operation IDs */

418

List<Long> getControlOpIdsList();

419

420

/** Get device name */

421

String getDevice();

422

423

/** Check if operation is a function */

424

boolean getIsFunction();

425

}

426

}

427

428

/**

429

* Response with operation results

430

*/

431

class EnqueueResponse {

432

/** Get list of operation results */

433

List<QueueResponse> getQueueResponseList();

434

435

/** Response for individual operations */

436

static class QueueResponse {

437

/** Get output tensor handles */

438

List<TensorHandle> getTensorList();

439

440

/** Get output shapes */

441

List<TensorShapeProto> getShapeList();

442

}

443

}

444

```

445

446

### Common Distributed Types

447

448

#### RunOptions

449

450

Options for controlling step execution behavior.

451

452

```java { .api }

453

/**

454

* Options for controlling step execution behavior

455

*/

456

class RunOptions {

457

/** Get trace level for profiling */

458

TraceLevel getTraceLevel();

459

460

/** Get timeout in milliseconds */

461

long getTimeoutInMs();

462

463

/** Get inter-op thread pool setting */

464

int getInterOpThreadPool();

465

466

/** Check if outputting partition graphs is enabled */

467

boolean getOutputPartitionGraphs();

468

469

/** Get debug options */

470

DebugOptions getDebugOptions();

471

472

/** Check if tensor allocations are reported upon out-of-memory (OOM) errors */

473

boolean getReportTensorAllocationsUponOom();

474

475

/** Get experimental options */

476

Experimental getExperimental();

477

478

static Builder newBuilder();

479

480

static class Builder {

481

Builder setTraceLevel(TraceLevel level);

482

Builder setTimeoutInMs(long timeout);

483

Builder setInterOpThreadPool(int pool);

484

Builder setOutputPartitionGraphs(boolean output);

485

Builder setDebugOptions(DebugOptions options);

486

RunOptions build();

487

}

488

489

enum TraceLevel {

490

NO_TRACE,

491

SOFTWARE_TRACE,

492

HARDWARE_TRACE,

493

FULL_TRACE

494

}

495

496

static class Experimental {

497

int getCollectiveGraphKey();

498

boolean getUseRunHandler();

499

}

500

}

501

```

502

503

#### RunMetadata

504

505

Metadata returned from step execution.

506

507

```java { .api }

508

/**

509

* Metadata returned from step execution

510

*/

511

class RunMetadata {

512

/** Get step execution statistics */

513

StepStats getStepStats();

514

515

/** Get cost graph information */

516

CostGraphDef getCostGraph();

517

518

/** Get partition graphs that were executed */

519

List<GraphDef> getPartitionGraphsList();

520

521

/** Get function graphs that were executed */

522

List<GraphDef> getFunctionGraphsList();

523

524

static Builder newBuilder();

525

}

526

```

527

528

#### CostGraphDef

529

530

Cost model information for operations.

531

532

```java { .api }

533

/**

534

* Cost model information for operations

535

*/

536

class CostGraphDef {

537

/** Get cost information for each node */

538

List<Node> getNodeList();

539

540

/** Cost information for a single node */

541

static class Node {

542

/** Get node name */

543

String getName();

544

545

/** Get device name */

546

String getDevice();

547

548

/** Get node ID */

549

int getId();

550

551

/** Get input information */

552

List<InputInfo> getInputInfoList();

553

554

/** Get output information */

555

List<OutputInfo> getOutputInfoList();

556

557

/** Get temporary memory used */

558

long getTempMemorySize();

559

560

/** Get persistent memory used */

561

long getPersistentMemorySize();

562

563

/** Get compute cost */

564

long getComputeCost();

565

566

/** Get compute time */

567

long getComputeTime();

568

569

/** Get memory time */

570

long getMemoryTime();

571

572

/** Check if this is the final node */

573

boolean getIsFinal();

574

575

/** Get control input nodes */

576

List<Integer> getControlInputList();

577

578

/** Check if the cost information for this node is inaccurate */

579

boolean getInaccurate();

580

}

581

582

/** Input information for cost calculation */

583

static class InputInfo {

584

int getPrecedingNode();

585

int getPrecedingPort();

586

}

587

588

/** Output information for cost calculation */

589

static class OutputInfo {

590

long getSize();

591

long getAliasInputPort();

592

TensorShapeProto getShape();

593

DataType getDtype();

594

}

595

}

596

```

597

598

#### ClusterDef

599

600

Defines the cluster topology and job configurations.

601

602

```java { .api }

603

/**

604

* Cluster topology definition

605

*/

606

class ClusterDef {

607

/** Get job definitions */

608

Map<String, JobDef> getJobMap();

609

610

/** Job definition within cluster */

611

static class JobDef {

612

/** Get job name */

613

String getName();

614

615

/** Get task index to address mapping */

616

Map<Integer, String> getTasksMap();

617

}

618

}

619

```

620

621

#### ServerDef

622

623

Defines server configuration for distributed execution.

624

625

```java { .api }

626

/**

627

* Server configuration for distributed execution

628

*/

629

class ServerDef {

630

/** Get cluster definition */

631

ClusterDef getCluster();

632

633

/** Get job name for this server */

634

String getJobName();

635

636

/** Get task index for this server */

637

int getTaskIndex();

638

639

/** Get default session configuration */

640

ConfigProto getDefaultSessionConfig();

641

642

/** Get server protocol (e.g., "grpc") */

643

String getProtocol();

644

645

/** Get server port */

646

int getPort();

647

}

648

```

649

650

**Usage Examples:**

651

652

```java

653

import org.tensorflow.distruntime.*;

654

655

// Define a cluster with chief and workers

656

ClusterDef cluster = ClusterDef.newBuilder()

657

.putJob("chief", JobDef.newBuilder()

658

.setName("chief")

659

.putTasks(0, "chief:2222")

660

.build())

661

.putJob("worker", JobDef.newBuilder()

662

.setName("worker")

663

.putTasks(0, "worker0:2222")

664

.putTasks(1, "worker1:2222")

665

.putTasks(2, "worker2:2222")

666

.build())

667

.putJob("ps", JobDef.newBuilder()

668

.setName("ps")

669

.putTasks(0, "ps0:2222")

670

.putTasks(1, "ps1:2222")

671

.build())

672

.build();

673

674

// Configure server as worker

675

ServerDef serverDef = ServerDef.newBuilder()

676

.setCluster(cluster)

677

.setJobName("worker")

678

.setTaskIndex(0)

679

.setProtocol("grpc")

680

.setPort(2222)

681

.setDefaultSessionConfig(ConfigProto.newBuilder()

682

.setAllowSoftPlacement(true)

683

.build())

684

.build();

685

```