0
# Distributed Runtime
1
2
Service interfaces and message types for distributed TensorFlow execution across multiple devices and machines. These types enable TensorFlow to run computations on clusters and coordinate between master and worker nodes.
3
4
## Capabilities
5
6
### Master Service Messages
7
8
The master service coordinates distributed TensorFlow sessions and manages the overall execution flow across worker nodes.
9
10
#### CreateSessionRequest/Response
11
12
Creates a new TensorFlow session on the cluster.
13
14
```java { .api }
15
/**
16
* Request to create a new session
17
*/
18
class CreateSessionRequest {
19
/** Get the computation graph */
20
GraphDef getGraphDef();
21
22
/** Get session configuration */
23
ConfigProto getConfig();
24
25
/** Get target specification (e.g., "grpc://localhost:2222") */
26
String getTarget();
27
28
/** Create a new builder */
29
static Builder newBuilder();
30
31
static class Builder {
32
Builder setGraphDef(GraphDef graphDef);
33
Builder setConfig(ConfigProto config);
34
Builder setTarget(String target);
35
CreateSessionRequest build();
36
}
37
}
38
39
/**
40
* Response with created session handle
41
*/
42
class CreateSessionResponse {
43
/** Get unique session handle */
44
String getSessionHandle();
45
46
/** Get cluster information */
47
ClusterDef getClusterDef();
48
49
/** Get graph version */
50
int getGraphVersion();
51
}
52
```
53
54
#### RunStepRequest/Response
55
56
Executes a computation step in a distributed session.
57
58
```java { .api }
59
/**
60
* Request to run a computation step
61
*/
62
class RunStepRequest {
63
/** Get session handle */
64
String getSessionHandle();
65
66
/** Get input feed mappings (tensor name -> tensor value) */
67
Map<String, TensorProto> getFeedMap();
68
69
/** Get output fetch names */
70
List<String> getFetchList();
71
72
/** Get target operation names to run */
73
List<String> getTargetList();
74
75
/** Get run options */
76
RunOptions getOptions();
77
78
/** Get partial run handle for partial execution */
79
String getPartialRunHandle();
80
81
/** Check if errors should be stored in the response body (see RunStepResponse status code and error message) */
82
boolean getStoreErrorsInResponseBody();
83
84
/** Create a new builder */
85
static Builder newBuilder();
86
87
static class Builder {
88
Builder setSessionHandle(String handle);
89
Builder putFeed(String key, TensorProto value);
90
Builder addFetch(String fetch);
91
Builder addTarget(String target);
92
Builder setOptions(RunOptions options);
93
Builder setPartialRunHandle(String handle);
94
RunStepRequest build();
95
}
96
}
97
98
/**
99
* Response with computation results
100
*/
101
class RunStepResponse {
102
/** Get output tensor values */
103
List<TensorProto> getTensorList();
104
105
/** Get execution metadata */
106
RunMetadata getMetadata();
107
108
/** Get step execution statistics */
109
StepStats getStepStats();
110
111
/** Get cost graph information */
112
CostGraphDef getCostGraph();
113
114
/** Get status code */
115
int getStatusCode();
116
117
/** Get error message if failed */
118
String getStatusErrorMessage();
119
}
120
```
121
122
**Usage Examples:**
123
124
```java
125
import org.tensorflow.distruntime.*;
126
import org.tensorflow.framework.*;
127
128
// Create a session on distributed cluster
129
CreateSessionRequest sessionRequest = CreateSessionRequest.newBuilder()
130
.setGraphDef(myGraphDef)
131
.setConfig(ConfigProto.newBuilder()
132
.setAllowSoftPlacement(true)
133
.putDeviceCount("GPU", 2)
134
.putDeviceCount("CPU", 4)
135
.build())
136
.setTarget("grpc://chief:2222")
137
.build();
138
139
// Run a training step
140
RunStepRequest stepRequest = RunStepRequest.newBuilder()
141
.setSessionHandle(sessionHandle)
142
.putFeed("input:0", inputTensor)
143
.putFeed("labels:0", labelTensor)
144
.addFetch("loss:0")
145
.addFetch("accuracy:0")
146
.addTarget("train_op")
147
.setOptions(RunOptions.newBuilder()
148
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
149
.setTimeoutInMs(30000)
150
.build())
151
.build();
152
```
153
154
#### ExtendSessionRequest/Response
155
156
Extends an existing session with additional graph nodes.
157
158
```java { .api }
159
/**
160
* Request to extend session with new graph nodes
161
*/
162
class ExtendSessionRequest {
163
/** Get session handle */
164
String getSessionHandle();
165
166
/** Get additional graph definition */
167
GraphDef getGraphDef();
168
169
/** Get current graph version */
170
int getCurrentGraphVersion();
171
172
/** Create a new builder */
173
static Builder newBuilder();
174
}
175
176
/**
177
* Response with updated graph version
178
*/
179
class ExtendSessionResponse {
180
/** Get new graph version */
181
int getNewGraphVersion();
182
}
183
```
184
185
#### ListDevicesRequest/Response
186
187
Lists available devices in the cluster.
188
189
```java { .api }
190
/**
191
* Request to list available devices
192
*/
193
class ListDevicesRequest {
194
/** Get session handle (optional) */
195
String getSessionHandle();
196
}
197
198
/**
199
* Response with device information
200
*/
201
class ListDevicesResponse {
202
/** Get list of local devices */
203
List<DeviceAttributes> getLocalDeviceList();
204
205
/** Get list of remote devices */
206
List<DeviceAttributes> getRemoteDeviceList();
207
}
208
```
209
210
### Worker Service Messages
211
212
Worker services execute graph partitions on individual machines in a distributed setup.
213
214
#### RegisterGraphRequest/Response
215
216
Registers a graph partition on a worker node.
217
218
```java { .api }
219
/**
220
* Request to register graph partition on worker
221
*/
222
class RegisterGraphRequest {
223
/** Get session handle */
224
String getSessionHandle();
225
226
/** Check if this creates a new session */
227
boolean getCreateWorkerSessionOnly();
228
229
/** Get graph definition */
230
GraphDef getGraphDef();
231
232
/** Check if the graph has control dependencies */
233
boolean getHasControlDependencies();
234
235
/** Get graph options */
236
GraphOptions getGraphOptions();
237
238
/** Get debug options */
239
DebugOptions getDebugOptions();
240
241
/** Get collective graph key */
242
long getCollectiveGraphKey();
243
244
/** Create a new builder */
245
static Builder newBuilder();
246
}
247
248
/**
249
* Response after registering graph
250
*/
251
class RegisterGraphResponse {
252
/** Get graph handle for future operations */
253
String getGraphHandle();
254
}
255
```
256
257
#### RunGraphRequest/Response
258
259
Executes a registered graph partition.
260
261
```java { .api }
262
/**
263
* Request to run registered graph partition
264
*/
265
class RunGraphRequest {
266
/** Get session handle */
267
String getSessionHandle();
268
269
/** Get graph handle */
270
String getGraphHandle();
271
272
/** Get step ID for coordination */
273
long getStepId();
274
275
/** Get execution count */
276
long getExecCount();
277
278
/** Get input tensors */
279
List<NamedTensorProto> getSendList();
280
281
/** Get output tensor names */
282
List<String> getRecvKeyList();
283
284
/** Check if this is a partial run */
285
boolean getIsPartial();
286
287
/** Check if this is the last partial run */
288
boolean getIsLastPartialRun();
289
290
/** Create a new builder */
291
static Builder newBuilder();
292
}
293
294
/**
295
* Response with computation results
296
*/
297
class RunGraphResponse {
298
/** Get output tensors */
299
List<NamedTensorProto> getRecvList();
300
301
/** Get step execution statistics */
302
StepStats getStepStats();
303
304
/** Get cost graph */
305
CostGraphDef getCostGraph();
306
307
/** Get partition graphs executed */
308
List<GraphDef> getPartitionGraphList();
309
}
310
```
311
312
**Usage Examples:**
313
314
```java
315
import org.tensorflow.distruntime.*;
316
317
// Register a graph partition on worker
318
RegisterGraphRequest registerRequest = RegisterGraphRequest.newBuilder()
319
.setSessionHandle(sessionHandle)
320
.setGraphDef(partitionedGraph)
321
.setHasControlDependencies(true)
322
.setGraphOptions(GraphOptions.newBuilder()
323
.setEnableRecvScheduling(true)
324
.build())
325
.build();
326
327
// Execute the registered graph
328
RunGraphRequest runRequest = RunGraphRequest.newBuilder()
329
.setSessionHandle(sessionHandle)
330
.setGraphHandle(graphHandle)
331
.setStepId(currentStepId)
332
.addSend(NamedTensorProto.newBuilder()
333
.setName("input_partition:0")
334
.setTensor(inputTensor)
335
.build())
336
.addRecvKey("output_partition:0")
337
.setIsPartial(false)
338
.build();
339
```
340
341
### Eager Service Messages
342
343
Services for TensorFlow's eager execution mode, allowing operations to be executed immediately.
344
345
#### CreateContextRequest/Response
346
347
Creates an eager execution context.
348
349
```java { .api }
350
/**
351
* Request to create eager execution context
352
*/
353
class CreateContextRequest {
354
/** Get server definition */
355
ServerDef getServerDef();
356
357
/** Check if async execution is enabled */
358
boolean getAsync();
359
360
/** Get keep alive interval in seconds */
361
int getKeepAliveSecs();
362
363
/** Get version compatibility requirements */
364
VersionDef getVersionDef();
365
366
/** Get cluster device filters */
367
ClusterDeviceFilters getClusterDeviceFilters();
368
369
/** Create a new builder */
370
static Builder newBuilder();
371
}
372
373
/**
374
* Response with context information
375
*/
376
class CreateContextResponse {
377
/** Get context ID */
378
long getContextId();
379
380
/** Get context view ID */
381
long getContextViewId();
382
383
/** Get device attributes */
384
List<DeviceAttributes> getDeviceAttributesList();
385
}
386
```
387
388
#### EnqueueRequest/Response
389
390
Enqueues operations for eager execution.
391
392
```java { .api }
393
/**
394
* Request to enqueue eager operations
395
*/
396
class EnqueueRequest {
397
/** Get context ID */
398
long getContextId();
399
400
/** Get list of operations to execute */
401
List<Operation> getQueueList();
402
403
/** Operation definition for eager execution */
404
static class Operation {
405
/** Get operation ID */
406
long getId();
407
408
/** Get operation name */
409
String getName();
410
411
/** Get operation attributes */
412
Map<String, AttrValue> getAttrsMap();
413
414
/** Get input handles */
415
List<RemoteTensorHandle> getInputsList();
416
417
/** Get control input operation IDs */
418
List<Long> getControlOpIdsList();
419
420
/** Get device name */
421
String getDevice();
422
423
/** Check if operation is a function */
424
boolean getIsFunction();
425
}
426
}
427
428
/**
429
* Response with operation results
430
*/
431
class EnqueueResponse {
432
/** Get list of operation results */
433
List<QueueResponse> getQueueResponseList();
434
435
/** Response for individual operations */
436
static class QueueResponse {
437
/** Get output tensor handles */
438
List<TensorHandle> getTensorList();
439
440
/** Get output shapes */
441
List<TensorShapeProto> getShapeList();
442
}
443
}
444
```
445
446
### Common Distributed Types
447
448
#### RunOptions
449
450
Options for controlling step execution behavior.
451
452
```java { .api }
453
/**
454
* Options for controlling step execution behavior
455
*/
456
class RunOptions {
457
/** Get trace level for profiling */
458
TraceLevel getTraceLevel();
459
460
/** Get timeout in milliseconds */
461
long getTimeoutInMs();
462
463
/** Get inter-op thread pool setting */
464
int getInterOpThreadPool();
465
466
/** Check if the executed partition graphs should be returned in RunMetadata */
467
boolean getOutputPartitionGraphs();
468
469
/** Get debug options */
470
DebugOptions getDebugOptions();
471
472
/** Check if tensor allocations should be reported when an out-of-memory (OOM) error occurs */
473
boolean getReportTensorAllocationsUponOom();
474
475
/** Get experimental options */
476
Experimental getExperimental();
477
478
static Builder newBuilder();
479
480
static class Builder {
481
Builder setTraceLevel(TraceLevel level);
482
Builder setTimeoutInMs(long timeout);
483
Builder setInterOpThreadPool(int pool);
484
Builder setOutputPartitionGraphs(boolean output);
485
Builder setDebugOptions(DebugOptions options);
486
RunOptions build();
487
}
488
489
enum TraceLevel {
490
NO_TRACE,
491
SOFTWARE_TRACE,
492
HARDWARE_TRACE,
493
FULL_TRACE
494
}
495
496
static class Experimental {
497
int getCollectiveGraphKey();
498
boolean getUseRunHandler();
499
}
500
}
501
```
502
503
#### RunMetadata
504
505
Metadata returned from step execution.
506
507
```java { .api }
508
/**
509
* Metadata returned from step execution
510
*/
511
class RunMetadata {
512
/** Get step execution statistics */
513
StepStats getStepStats();
514
515
/** Get cost graph information */
516
CostGraphDef getCostGraph();
517
518
/** Get partition graphs that were executed */
519
List<GraphDef> getPartitionGraphsList();
520
521
/** Get function graphs that were executed */
522
List<GraphDef> getFunctionGraphsList();
523
524
static Builder newBuilder();
525
}
526
```
527
528
#### CostGraphDef
529
530
Cost model information for operations.
531
532
```java { .api }
533
/**
534
* Cost model information for operations
535
*/
536
class CostGraphDef {
537
/** Get cost information for each node */
538
List<Node> getNodeList();
539
540
/** Cost information for a single node */
541
static class Node {
542
/** Get node name */
543
String getName();
544
545
/** Get device name */
546
String getDevice();
547
548
/** Get node ID */
549
int getId();
550
551
/** Get input information */
552
List<InputInfo> getInputInfoList();
553
554
/** Get output information */
555
List<OutputInfo> getOutputInfoList();
556
557
/** Get temporary memory used */
558
long getTempMemorySize();
559
560
/** Get persistent memory used */
561
long getPersistentMemorySize();
562
563
/** Get compute cost */
564
long getComputeCost();
565
566
/** Get compute time */
567
long getComputeTime();
568
569
/** Get memory time */
570
long getMemoryTime();
571
572
/** Check if this node is part of the final output */
573
boolean getIsFinal();
574
575
/** Get control input nodes */
576
List<Integer> getControlInputList();
577
578
/** Check if the cost information for this node is inaccurate */
579
boolean getInaccurate();
580
}
581
582
/** Input information for cost calculation */
583
static class InputInfo {
584
int getPrecedingNode();
585
int getPrecedingPort();
586
}
587
588
/** Output information for cost calculation */
589
static class OutputInfo {
590
long getSize();
591
long getAliasInputPort();
592
TensorShapeProto getShape();
593
DataType getDtype();
594
}
595
}
596
```
597
598
#### ClusterDef
599
600
Defines the cluster topology and job configurations.
601
602
```java { .api }
603
/**
604
* Cluster topology definition
605
*/
606
class ClusterDef {
607
/** Get job definitions */
608
Map<String, JobDef> getJobMap();
609
610
/** Job definition within cluster */
611
static class JobDef {
612
/** Get job name */
613
String getName();
614
615
/** Get task index to address mapping */
616
Map<Integer, String> getTasksMap();
617
}
618
}
619
```
620
621
#### ServerDef
622
623
Defines server configuration for distributed execution.
624
625
```java { .api }
626
/**
627
* Server configuration for distributed execution
628
*/
629
class ServerDef {
630
/** Get cluster definition */
631
ClusterDef getCluster();
632
633
/** Get job name for this server */
634
String getJobName();
635
636
/** Get task index for this server */
637
int getTaskIndex();
638
639
/** Get default session configuration */
640
ConfigProto getDefaultSessionConfig();
641
642
/** Get server protocol (e.g., "grpc") */
643
String getProtocol();
644
645
/** Get server port */
646
int getPort();
647
}
648
```
649
650
**Usage Examples:**
651
652
```java
653
import org.tensorflow.distruntime.*;
654
655
// Define a cluster with chief and workers
656
ClusterDef cluster = ClusterDef.newBuilder()
657
.putJob("chief", JobDef.newBuilder()
658
.setName("chief")
659
.putTasks(0, "chief:2222")
660
.build())
661
.putJob("worker", JobDef.newBuilder()
662
.setName("worker")
663
.putTasks(0, "worker0:2222")
664
.putTasks(1, "worker1:2222")
665
.putTasks(2, "worker2:2222")
666
.build())
667
.putJob("ps", JobDef.newBuilder()
668
.setName("ps")
669
.putTasks(0, "ps0:2222")
670
.putTasks(1, "ps1:2222")
671
.build())
672
.build();
673
674
// Configure server as worker
675
ServerDef serverDef = ServerDef.newBuilder()
676
.setCluster(cluster)
677
.setJobName("worker")
678
.setTaskIndex(0)
679
.setProtocol("grpc")
680
.setPort(2222)
681
.setDefaultSessionConfig(ConfigProto.newBuilder()
682
.setAllowSoftPlacement(true)
683
.build())
684
.build();
685
```