# Distributed Computing

XGBoost supports distributed training and prediction across multiple workers and computing environments. This page covers the Dask and PySpark integrations, the low-level `xgboost.collective` communication module, and the tracker that coordinates workers.

## Capabilities

### Collective Communication

Low-level distributed communication primitives (`xgboost.collective`) for custom distributed setups.

```python { .api }
import xgboost.collective as collective

from dataclasses import dataclass
from enum import IntEnum
from typing import Optional


def init(**args):
    """
    Initialize the collective communicator.

    Parameters:
    - **args: Communicator settings as keyword arguments, for example
      dmlc_communicator ("rabit" or "federated"), dmlc_tracker_uri,
      dmlc_tracker_port, dmlc_task_id, dmlc_retry and dmlc_timeout.
      In a tracker-based setup these usually come from
      RabitTracker.worker_args().

    Returns:
    None
    """


def finalize():
    """Finalize collective communication."""


def get_rank():
    """
    Get rank of current process.

    Returns:
    int: Process rank
    """


def get_world_size():
    """
    Get total number of processes.

    Returns:
    int: World size
    """


def is_distributed():
    """
    Check if running in distributed mode.

    Returns:
    bool: True if distributed
    """


def broadcast(data, root):
    """
    Broadcast a Python object from the root process to all processes.

    The object is serialized with pickle, so any picklable value can be
    broadcast. The root process must pass a non-None value; other ranks
    may pass None because only the root's value is sent.

    Parameters:
    - data: Object to broadcast
    - root: Rank of the process that owns the data

    Returns:
    The broadcast object, on every process
    """


def allreduce(data, op):
    """
    Element-wise all-reduce of a NumPy array across processes.

    Parameters:
    - data: numpy.ndarray to reduce. Other types (a plain int, a list, ...)
      raise TypeError; wrap scalars as np.array([value]).
    - op: Reduction operation, a member of Op (Op.MAX, Op.MIN, Op.SUM,
      Op.BITWISE_AND, Op.BITWISE_OR, Op.BITWISE_XOR)

    Returns:
    numpy.ndarray with the reduced values
    """


def communicator_print(message):
    """
    Distributed-aware printing function.

    Parameters:
    - message: Message to print

    Returns:
    None
    """


def get_processor_name():
    """
    Get processor name for current process.

    Returns:
    str: Processor name
    """


def signal_error(error_message):
    """
    Signal error across all processes.

    Parameters:
    - error_message: Error message to signal

    Returns:
    None
    """


class Op(IntEnum):
    """Reduction operations supported by allreduce (an IntEnum)."""

    MAX = 0
    MIN = 1
    SUM = 2
    BITWISE_AND = 3
    BITWISE_OR = 4
    BITWISE_XOR = 5


@dataclass
class Config:
    """
    User configuration for the communicator context.

    A field left as None means "use XGBoost's default".

    Attributes:
    - retry: Number of retries on network errors (see dmlc_retry in init)
    - timeout: Timeout in seconds for collective operations (see
      dmlc_timeout in init); applies to the communicators, not the tracker
    - tracker_host_ip: Host IP of the tracker (see RabitTracker)
    - tracker_port: Port of the tracker (see RabitTracker)
    - tracker_timeout: Timeout for starting and shutting down the
      communication group (see RabitTracker)
    """

    retry: Optional[int] = None
    timeout: Optional[int] = None
    tracker_host_ip: Optional[str] = None
    tracker_port: Optional[int] = None
    tracker_timeout: Optional[int] = None


class CommunicatorContext:
    """Context manager that calls init() on entry and finalize() on exit."""

    def __init__(self, **kwargs):
        """
        Initialize communicator context.

        Parameters:
        **kwargs: Communicator settings, forwarded to init()
        """

    def __enter__(self):
        """Enter context and initialize communication."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context and finalize communication."""
```

### Dask Integration

Distributed training and prediction using Dask for Python-native scaling.

```python { .api }
from xgboost.dask import (
    DaskDMatrix,
    DaskQuantileDMatrix,
    DaskXGBRegressor,
    DaskXGBClassifier,
    DaskXGBRanker,
    train,
    predict,
)


class DaskDMatrix:
    def __init__(
        self,
        client,
        data,
        label=None,
        *,
        weight=None,
        base_margin=None,
        missing=None,
        silent=False,
        feature_names=None,
        feature_types=None,
        group=None,
        qid=None,
        label_lower_bound=None,
        label_upper_bound=None,
        feature_weights=None,
        enable_categorical=False,
    ):
        """
        Dask-compatible DMatrix.

        All parameters after label are keyword-only.

        Parameters:
        - client: Dask client, or None to use the default (current) client
        - data: Dask array or Dask DataFrame
        - (other parameters same as DMatrix)
        """


class DaskXGBRegressor:
    def __init__(
        self,
        n_estimators=100,
        max_depth=None,
        learning_rate=None,
        verbosity=None,
        objective=None,
        booster=None,
        tree_method=None,
        n_jobs=None,
        **kwargs,
    ):
        """
        Dask XGBoost regressor.

        Parameters same as XGBRegressor. Training options such as
        eval_metric, early_stopping_rounds and callbacks are passed here
        (through **kwargs), not to fit().
        """

    def fit(
        self,
        X,
        y,
        *,
        sample_weight=None,
        base_margin=None,
        eval_set=None,
        sample_weight_eval_set=None,
        base_margin_eval_set=None,
        verbose=True,
        xgb_model=None,
        feature_weights=None,
    ):
        """
        Fit Dask XGBoost model.

        All parameters after y are keyword-only. eval_metric,
        early_stopping_rounds and callbacks were deprecated as fit()
        arguments in XGBoost 1.6 and have since been removed; set them on
        the constructor instead.

        Returns:
        self
        """

    def predict(
        self,
        X,
        output_margin=False,
        validate_features=True,
        base_margin=None,
        iteration_range=None,
    ):
        """Predict using Dask XGBoost model (returns a lazy Dask collection)."""


class DaskXGBClassifier:
    """Dask XGBoost classifier with same interface as DaskXGBRegressor."""

    def predict_proba(self, X, **kwargs):
        """Predict class probabilities."""


def train(
    client,
    params,
    dtrain,
    num_boost_round=10,
    *,
    evals=None,
    obj=None,
    maximize=None,
    early_stopping_rounds=None,
    xgb_model=None,
    verbose_eval=True,
    callbacks=None,
    custom_metric=None,
):
    """
    Distributed training with Dask.

    Unlike xgb.train there is no evals_result argument: the evaluation
    history is part of the returned dictionary.

    Parameters:
    - client: Dask client, or None to use the default client
    - dtrain: DaskDMatrix or DaskQuantileDMatrix with the training data
    - (other parameters same as xgb.train)

    Returns:
    dict with two keys:
    - "booster": the trained Booster
    - "history": evaluation history, same format as evals_result in xgb.train
    """


def predict(client, model, data, **kwargs):
    """
    Distributed prediction with Dask.

    Parameters:
    - client: Dask client, or None to use the default client
    - model: Booster, the dict returned by train(), or a Dask Future of a
      Booster. For the sklearn-style Dask estimators use their own
      predict() method instead.
    - data: DaskDMatrix, Dask array or Dask DataFrame
    - **kwargs: Prediction options such as output_margin, pred_leaf,
      pred_contribs, validate_features, iteration_range, strict_shape

    Returns:
    A Dask array when data is a Dask array or DaskDMatrix; a Dask Series
    or DataFrame (depending on output shape) when data is a Dask DataFrame
    """
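```

### Spark Integration

Integration with Apache Spark through PySpark ML estimators (`xgboost.spark`). Training runs in Python on the Spark executors, and each XGBoost worker runs as one Spark task. The JVM package, XGBoost4J-Spark, is separate and is not covered here.

```python { .api }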
from xgboost.spark import (
    SparkXGBRegressor,
    SparkXGBClassifier,
    SparkXGBRanker,
    SparkXGBRegressorModel,
    SparkXGBClassifierModel,
    SparkXGBRankerModel,
)


class SparkXGBRegressor:
    def __init__(
        self,
        *,
        features_col="features",
        label_col="label",
        prediction_col="prediction",
        pred_contrib_col=None,
        validation_indicator_col=None,
        weight_col=None,
        base_margin_col=None,
        num_workers=1,
        use_gpu=None,
        device=None,
        force_repartition=False,
        repartition_random_shuffle=False,
        enable_sparse_data_optim=False,
        **kwargs,
    ):
        """
        Spark XGBoost regressor (a PySpark ML Estimator).

        All arguments are keyword-only.

        Parameters:
        - features_col: Name of the vector features column, or a list of
          numeric feature column names
        - label_col: Name of the label column
        - prediction_col: Name of the output prediction column
        - weight_col, base_margin_col, validation_indicator_col: Optional
          columns holding sample weights, base margins and a boolean flag
          marking validation rows
        - num_workers: Number of XGBoost workers; each runs as one Spark task
        - device: "cpu" or "cuda"
        - use_gpu: Deprecated, use device="cuda" instead
        - **kwargs: XGBoost parameters such as max_depth, learning_rate,
          n_estimators and objective. n_jobs / nthread are not supported;
          the per-worker thread count comes from spark.task.cpus.
        """

    def fit(self, dataset):
        """
        Fit model on Spark DataFrame.

        Parameters:
        - dataset: Spark DataFrame with features and labels

        Returns:
        SparkXGBRegressorModel
        """


class SparkXGBClassifier:
    """
    Spark XGBoost classifier.

    Takes the same keyword-only arguments as SparkXGBRegressor, plus
    probability_col (default "probability") and raw_prediction_col
    (default "rawPrediction") naming the extra output columns.
    """

    def fit(self, dataset):
        """
        Fit classifier on Spark DataFrame.

        Returns:
        SparkXGBClassifierModel
        """


class SparkXGBRegressorModel:
    def transform(self, dataset):
        """
        Transform Spark DataFrame with predictions.

        Parameters:
        - dataset: Input Spark DataFrame

        Returns:
        Spark DataFrame with an added prediction column (prediction_col)
        """


class SparkXGBClassifierModel:
    """Spark XGBoost classifier model with same interface as regressor model."""

    def transform(self, dataset):
        """
        Transform with class predictions and probabilities.

        Returns:
        Spark DataFrame with added prediction, rawPrediction and probability
        columns (names set by prediction_col, raw_prediction_col and
        probability_col)
        """
```

### Distributed Training Utilities

The tracker coordinates workers while they set up the collective communication group in custom distributed setups.

```python { .api }
# Defined in xgboost.tracker:  from xgboost.tracker import RabitTracker
class RabitTracker:
    def __init__(
        self,
        n_workers,
        host_ip,
        port=0,
        *,
        sortby="host",
        timeout=0,
    ):
        """
        Tracker that coordinates workers while they bootstrap the collective
        communication group.

        This replaces the old dmlc tracker API (hostIP, nslave, slave_env).

        Parameters:
        - n_workers: Total number of workers in the communication group
        - host_ip: IP address of the tracker host
        - port: Port to listen on; 0 lets the OS choose a free port
        - sortby: How workers are ordered for rank assignment, "host" or "task"
        - timeout: Timeout in seconds for setting up and shutting down the
          communication group (0 = no timeout). It does not limit collective
          operations once the group is running.
        """

    def start(self):
        """Start the tracker without blocking the caller."""

    def wait_for(self, timeout=None):
        """Block until the tracker finishes (all workers are done)."""

    def worker_args(self):
        """
        Get the connection arguments for workers.

        Returns:
        dict (e.g. dmlc_tracker_uri, dmlc_tracker_port) to pass to
        collective.init(**args) or collective.CommunicatorContext(**args)
        """
```

## Usage Examples

### Dask Distributed Training

```python
import dask.array as da
from dask.distributed import Client
from xgboost.dask import DaskXGBRegressor

# Connect to a running Dask scheduler. The context manager closes the
# connection on exit; while open, the client is the default client, so the
# estimator picks it up automatically.
with Client("scheduler-address:8786") as client:
    # Create distributed data (10 chunks of 1000 rows)
    X = da.random.random((10000, 10), chunks=(1000, 10))
    y = da.random.random(10000, chunks=1000)

    # Train distributed model
    model = DaskXGBRegressor(n_estimators=100, max_depth=3)
    model.fit(X, y)

    # predict() returns a lazy Dask array; compute it while the client is
    # still connected
    predictions = model.predict(X).compute()
```

### Spark Training

```python
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBRegressor

# Reuse the active Spark session, or start one named "XGBoost"
spark = SparkSession.builder.appName("XGBoost").getOrCreate()

# The libsvm reader yields "label" and "features" columns, matching the
# estimator's default label_col and features_col
train_df = spark.read.format("libsvm").load("data.txt")

# XGBoost parameters go in as keyword arguments; num_workers sets how many
# Spark tasks train in parallel
estimator = SparkXGBRegressor(
    max_depth=3,
    n_estimators=100,
    num_workers=4,
)
model = estimator.fit(train_df)

# Append the prediction column to the input rows
predictions = model.transform(train_df)
```

### Collective Communication

```python
import numpy as np
import xgboost.collective as collective

# Initialize collective communication. In a multi-worker job, pass the
# tracker's connection settings, e.g. collective.init(**tracker.worker_args())
collective.init()

try:
    # Get process information
    rank = collective.get_rank()
    world_size = collective.get_world_size()

    # Broadcast data from rank 0. broadcast pickles the object, so any
    # picklable value works; non-root ranks pass None.
    data = [1, 2, 3, 4, 5] if rank == 0 else None
    data = collective.broadcast(data, root=0)

    # All-reduce sum across processes. allreduce only accepts numpy arrays,
    # so wrap the scalar in a one-element array and unwrap the result.
    local_sum = np.array([sum(range(rank * 10, (rank + 1) * 10))])
    global_sum = collective.allreduce(local_sum, collective.Op.SUM)[0]
finally:
    # Finalize even if a collective call raised
    collective.finalize()
```