or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-features.md, devices-distributed.md, index.md, mathematical-functions.md, neural-networks.md, tensor-operations.md, training.md

docs/neural-networks.md

0

# Neural Networks

1

2

Complete neural network building blocks for deep learning models: layers, activation functions, loss functions, and containers. The `torch.nn` module provides these as high-level, composable abstractions for constructing neural networks.

3

4

## Capabilities

5

6

### Base Classes

7

8

Core classes that form the foundation of all neural network components.

9

10

```python { .api }

class Module:
    """Base class for all neural network modules.

    Subclasses assign sub-modules and parameters as attributes in
    ``__init__`` and implement the computation in ``forward``.
    """

    def __init__(self): ...

    def forward(self, *input):
        """Define the computation performed on every call.

        Must be overridden by subclasses. Call the module instance
        (``module(x)``) rather than ``forward`` directly so registered hooks run.
        """

    def parameters(self, recurse=True):
        """Return an iterator over module parameters.

        With ``recurse=True`` parameters of all sub-modules are included;
        otherwise only parameters that are direct members of this module.
        """

    def named_parameters(self, prefix='', recurse=True):
        """Return an iterator over ``(name, parameter)`` pairs; names are prefixed with ``prefix``."""

    def modules(self):
        """Return an iterator over all modules in the network, including ``self``."""

    def named_modules(self, memo=None, prefix='', remove_duplicate=True):
        """Return an iterator over ``(name, module)`` pairs, including ``self``.

        With ``remove_duplicate=True`` a module instance reachable through
        several attributes is yielded only once.
        """

    def children(self):
        """Return an iterator over immediate child modules."""

    def named_children(self):
        """Return an iterator over ``(name, child)`` pairs for immediate children."""

    def train(self, mode=True):
        """Set this module and all sub-modules to training mode (or eval mode if ``mode=False``); returns ``self``.

        Matters for layers such as Dropout and BatchNorm that behave
        differently during training and evaluation.
        """

    def eval(self):
        """Set the module to evaluation mode; equivalent to ``self.train(False)``."""

    def zero_grad(self, set_to_none=True):
        """Reset the gradients of all parameters.

        Since PyTorch 2.0 the default ``set_to_none=True`` sets each ``.grad``
        to ``None``. Pass ``set_to_none=False`` to fill the existing gradient
        tensors with zeros instead.
        """

    def to(self, *args, **kwargs):
        """Move and/or cast parameters and buffers to a device and/or dtype (in place); returns ``self``."""

    def cuda(self, device=None):
        """Move all parameters and buffers to a CUDA device; returns ``self``."""

    def cpu(self):
        """Move all parameters and buffers to the CPU; returns ``self``."""

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return a dictionary containing the module's parameters and persistent buffers."""

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load parameters and buffers from ``state_dict`` into this module.

        With ``strict=True`` the keys must exactly match this module's
        ``state_dict()`` keys. ``assign=True`` (PyTorch >= 2.1) assigns the
        given tensors instead of copying into the existing ones.
        """


class Parameter(Tensor):
    """Tensor that is registered as a trainable parameter when assigned as a Module attribute."""

    def __init__(self, data=None, requires_grad=True): ...


class UninitializedParameter(Parameter):
    """Parameter whose shape is not known yet; used by lazy modules such as LazyLinear."""

    def __init__(self, requires_grad=True, device=None, dtype=None): ...

```

53

54

### Linear Layers

55

56

Dense layers that perform linear transformations.

57

58

```python { .api }

class Linear(Module):
    """Affine transformation: y = x A^T + b.

    Maps input of shape ``(*, in_features)`` to output of shape ``(*, out_features)``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Bilinear(Module):
    """Bilinear transformation of two inputs: y = x1^T A x2 + b."""

    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None): ...

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor: ...


class LazyLinear(Module):
    """Linear layer whose ``in_features`` is inferred from the first input it sees.

    The weight starts as an UninitializedParameter and is materialized on the
    first forward call.
    """

    def __init__(self, out_features: int, bias: bool = True, device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Identity(Module):
    """Placeholder layer that returns its input unchanged; constructor arguments are ignored."""

    def __init__(self, *args, **kwargs): ...

    def forward(self, input: Tensor) -> Tensor: ...

```

79

80

### Convolution Layers

81

82

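Convolution and transposed-convolution layers over 1D, 2D, and 3D inputs. They are used for spatial feature extraction and, in the transposed variants, for upsampling.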

83

84

```python { .api }

class Conv1d(Module):
    """1D convolution over input of shape (N, C_in, L) or unbatched (C_in, L).

    ``padding_mode`` is one of 'zeros', 'reflect', 'replicate' or 'circular'.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Conv2d(Module):
    """2D convolution over input of shape (N, C_in, H, W) or unbatched (C_in, H, W)."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Conv3d(Module):
    """3D convolution over input of shape (N, C_in, D, H, W) or unbatched (C_in, D, H, W)."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class ConvTranspose1d(Module):
    """1D transposed convolution (gradient of Conv1d w.r.t. its input), typically used for upsampling.

    Note the parameter order differs from Conv1d: ``output_padding`` follows
    ``padding`` and ``dilation`` follows ``bias``. ``output_size`` in ``forward``
    resolves the ambiguous output length when ``stride > 1``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor, output_size=None) -> Tensor: ...


class ConvTranspose2d(Module):
    """2D transposed convolution; see ConvTranspose1d for parameter-order and ``output_size`` notes."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor, output_size=None) -> Tensor: ...


class ConvTranspose3d(Module):
    """3D transposed convolution; see ConvTranspose1d for parameter-order and ``output_size`` notes."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros',
                 device=None, dtype=None): ...

    def forward(self, input: Tensor, output_size=None) -> Tensor: ...

```

121

122

### Activation Functions

123

124

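Activation functions that introduce non-linearity between layers. Most are applied element-wise; `Softmax` and `LogSoftmax` instead normalize along a chosen dimension.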

125

126

```python { .api }

class ReLU(Module):
    """Rectified Linear Unit: max(0, x)."""

    def __init__(self, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class ReLU6(Module):
    """ReLU clamped to a maximum value of 6: min(max(0, x), 6)."""

    def __init__(self, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class LeakyReLU(Module):
    """Leaky ReLU: max(0, x) + negative_slope * min(0, x).

    For 0 <= negative_slope <= 1 this equals max(negative_slope * x, x).
    """

    def __init__(self, negative_slope: float = 0.01, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class PReLU(Module):
    """Parametric ReLU: like LeakyReLU, but the negative slope is learned.

    ``num_parameters`` is 1 (one shared slope) or the number of input
    channels; ``init`` is the initial slope value.
    """

    def __init__(self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class ELU(Module):
    """Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1)."""

    def __init__(self, alpha: float = 1.0, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class SELU(Module):
    """Scaled Exponential Linear Unit: ELU with fixed alpha (~1.6733) multiplied by a fixed scale (~1.0507)."""

    def __init__(self, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class GELU(Module):
    """Gaussian Error Linear Unit: x * Phi(x).

    ``approximate`` is 'none' (exact) or 'tanh' (tanh-based approximation).
    """

    def __init__(self, approximate: str = 'none'): ...

    def forward(self, input: Tensor) -> Tensor: ...


class SiLU(Module):
    """Sigmoid Linear Unit (Swish): x * sigmoid(x)."""

    def __init__(self, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Mish(Module):
    """Mish activation: x * tanh(softplus(x))."""

    def __init__(self, inplace: bool = False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Sigmoid(Module):
    """Sigmoid activation: 1 / (1 + exp(-x))."""

    def __init__(self): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Tanh(Module):
    """Hyperbolic tangent activation."""

    def __init__(self): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Softmax(Module):
    """Softmax along ``dim``: exp(x_i) / sum_j exp(x_j).

    Always pass ``dim`` explicitly; relying on the implicit choice when it is
    None is deprecated.
    """

    def __init__(self, dim=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class LogSoftmax(Module):
    """log(softmax(x)) along ``dim``, computed in a numerically stable way.

    Pass ``dim`` explicitly, as with Softmax.
    """

    def __init__(self, dim=None): ...

    def forward(self, input: Tensor) -> Tensor: ...

```

192

193

### Normalization Layers

194

195

Normalization techniques for training stability and performance.

196

197

```python { .api }

class BatchNorm1d(Module):
    """Batch normalization over input of shape (N, C) or (N, C, L).

    Uses batch statistics in training mode and running estimates in eval mode
    (when ``track_running_stats=True``). ``momentum=None`` switches to a
    cumulative moving average.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class BatchNorm2d(Module):
    """Batch normalization over input of shape (N, C, H, W); see BatchNorm1d for train/eval behavior."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class BatchNorm3d(Module):
    """Batch normalization over input of shape (N, C, D, H, W); see BatchNorm1d for train/eval behavior."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class LayerNorm(Module):
    """Layer normalization over the trailing dimensions given by ``normalized_shape``.

    Statistics are computed per sample, so behavior is identical in training
    and eval mode. ``bias=False`` (PyTorch >= 2.1) keeps only the scale.
    """

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, bias=True,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class GroupNorm(Module):
    """Group normalization: splits channels into ``num_groups`` groups and normalizes within each.

    ``num_channels`` must be divisible by ``num_groups``.
    """

    def __init__(self, num_groups: int, num_channels: int, eps=1e-05, affine=True, device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class InstanceNorm1d(Module):
    """Instance normalization of each channel of each sample over (N, C, L) or unbatched (C, L).

    Unlike BatchNorm, ``affine`` and ``track_running_stats`` default to False.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class InstanceNorm2d(Module):
    """Instance normalization over (N, C, H, W) or unbatched (C, H, W)."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class InstanceNorm3d(Module):
    """Instance normalization over (N, C, D, H, W) or unbatched (C, D, H, W)."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False,
                 device=None, dtype=None): ...

    def forward(self, input: Tensor) -> Tensor: ...

```

238

239

### Pooling Layers

240

241

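Pooling layers that downsample spatial dimensions. The max and average variants use a fixed kernel; the adaptive variants take a target output size instead.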

242

243

```python { .api }

class MaxPool1d(Module):
    """1D max pooling over (N, C, L) or (C, L).

    ``stride`` defaults to ``kernel_size``. With ``return_indices=True``,
    ``forward`` returns ``(output, indices)`` instead of a single tensor.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class MaxPool2d(Module):
    """2D max pooling over (N, C, H, W) or (C, H, W); see MaxPool1d for ``stride`` and ``return_indices``."""

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class MaxPool3d(Module):
    """3D max pooling over (N, C, D, H, W) or (C, D, H, W); see MaxPool1d for ``stride`` and ``return_indices``."""

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AvgPool1d(Module):
    """1D average pooling; ``stride`` defaults to ``kernel_size``.

    ``count_include_pad`` controls whether zero padding is counted in the average.
    """

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AvgPool2d(Module):
    """2D average pooling; if ``divisor_override`` is set it replaces the pooling-region size as divisor."""

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AvgPool3d(Module):
    """3D average pooling; if ``divisor_override`` is set it replaces the pooling-region size as divisor."""

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveMaxPool1d(Module):
    """1D adaptive max pooling to a target ``output_size``; kernel and stride are derived from the input size.

    With ``return_indices=True``, ``forward`` returns ``(output, indices)``.
    """

    def __init__(self, output_size, return_indices=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveMaxPool2d(Module):
    """2D adaptive max pooling; ``output_size`` is an int or (H, W), where None keeps that input dimension."""

    def __init__(self, output_size, return_indices=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveMaxPool3d(Module):
    """3D adaptive max pooling; ``output_size`` is an int or (D, H, W), where None keeps that input dimension."""

    def __init__(self, output_size, return_indices=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveAvgPool1d(Module):
    """1D adaptive average pooling to a target ``output_size``."""

    def __init__(self, output_size): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveAvgPool2d(Module):
    """2D adaptive average pooling; ``(1, 1)`` gives global average pooling independent of input size."""

    def __init__(self, output_size): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AdaptiveAvgPool3d(Module):
    """3D adaptive average pooling to a target ``output_size`` (int or (D, H, W))."""

    def __init__(self, output_size): ...

    def forward(self, input: Tensor) -> Tensor: ...

```

294

295

### Loss Functions

296

297

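Loss functions for training neural networks. Each is called as `criterion(input, target)`; with the default `reduction='mean'` the individual losses are averaged into a single scalar.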

298

299

```python { .api }

# In every loss below, ``size_average`` and ``reduce`` are deprecated; use
# ``reduction`` ('none' | 'mean' | 'sum') instead.


class MSELoss(Module):
    """Mean squared error: mean((input - target)^2) under the default reduction."""

    def __init__(self, size_average=None, reduce=None, reduction='mean'): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class L1Loss(Module):
    """Mean absolute error: mean(|input - target|) under the default reduction."""

    def __init__(self, size_average=None, reduce=None, reduction='mean'): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class CrossEntropyLoss(Module):
    """Cross-entropy loss for classification; equivalent to LogSoftmax followed by NLLLoss.

    ``input`` holds unnormalized logits of shape (N, C); ``target`` holds
    class indices in [0, C) or class probabilities. Targets equal to
    ``ignore_index`` do not contribute.
    """

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class NLLLoss(Module):
    """Negative log-likelihood loss; ``input`` must contain log-probabilities (e.g. LogSoftmax output)."""

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class BCELoss(Module):
    """Binary cross-entropy; ``input`` must already be probabilities in [0, 1]."""

    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class BCEWithLogitsLoss(Module):
    """Sigmoid followed by binary cross-entropy, fused for numerical stability; ``input`` holds raw logits.

    ``pos_weight`` scales the loss of positive examples.
    """

    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class KLDivLoss(Module):
    """Kullback-Leibler divergence loss.

    ``input`` must be log-probabilities. ``target`` is probabilities, or
    log-probabilities when ``log_target=True``. Use ``reduction='batchmean'``
    to obtain the mathematically correct KL value per batch.
    """

    def __init__(self, size_average=None, reduce=None, reduction='mean', log_target=False): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class SmoothL1Loss(Module):
    """Smooth L1 loss: 0.5 * d^2 / beta for |d| < beta, else |d| - 0.5 * beta, with d = input - target.

    Closely related to, but not the same as, HuberLoss: it equals
    ``HuberLoss(delta=beta) / beta``.
    """

    def __init__(self, size_average=None, reduce=None, reduction='mean', beta=1.0): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...


class HuberLoss(Module):
    """Huber loss: 0.5 * d^2 for |d| < delta, else delta * (|d| - 0.5 * delta), with d = input - target."""

    def __init__(self, reduction='mean', delta=1.0): ...

    def forward(self, input: Tensor, target: Tensor) -> Tensor: ...

```

345

346

### Recurrent Neural Networks

347

348

RNN, LSTM, and GRU layers, plus their single-time-step cell variants, for processing sequential data.

349

350

```python { .api }

class RNN(Module):
    """Multi-layer Elman RNN with tanh or ReLU non-linearity.

    ``input`` is (L, N, H_in), or (N, L, H_in) when ``batch_first=True``.
    ``hx`` (initial hidden state) defaults to zeros. Returns ``(output, h_n)``.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True,
                 batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tuple[Tensor, Tensor]: ...


class LSTM(Module):
    """Multi-layer Long Short-Term Memory network.

    ``hx`` is an optional ``(h_0, c_0)`` tuple (zeros if omitted). Returns
    ``(output, (h_n, c_n))``. ``proj_size > 0`` adds a projection on the hidden state.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False,
                 dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: ...


class GRU(Module):
    """Multi-layer Gated Recurrent Unit network; returns ``(output, h_n)`` like RNN."""

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False,
                 dropout=0.0, bidirectional=False, device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tuple[Tensor, Tensor]: ...


class RNNCell(Module):
    """Single time step of an Elman RNN; returns the next hidden state (``hx`` defaults to zeros)."""

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity='tanh', device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tensor: ...


class LSTMCell(Module):
    """Single LSTM time step; ``hx`` is an optional ``(h, c)`` tuple and the result is ``(h', c')``."""

    def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tuple[Tensor, Tensor]: ...


class GRUCell(Module):
    """Single GRU time step; returns the next hidden state (``hx`` defaults to zeros)."""

    def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None): ...

    def forward(self, input, hx=None) -> Tensor: ...

```

384

385

### Transformer Components

386

387

Transformer architecture components for attention-based models.

388

389

```python { .api }

class Transformer(Module):
    """Complete encoder-decoder transformer.

    Inputs are (S, N, E) / (T, N, E), or batch-first when ``batch_first=True``.
    ``norm_first=True`` applies LayerNorm before (rather than after) each sub-layer.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation='relu', custom_encoder=None, custom_decoder=None,
                 layer_norm_eps=1e-05, batch_first=False, norm_first=False): ...

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None) -> Tensor: ...


class TransformerEncoder(Module):
    """Stack of ``num_layers`` copies of ``encoder_layer``, optionally followed by ``norm``."""

    def __init__(self, encoder_layer, num_layers, norm=None): ...

    def forward(self, src, mask=None, src_key_padding_mask=None) -> Tensor: ...


class TransformerEncoderLayer(Module):
    """Single encoder layer: self-attention followed by a feed-forward network."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu',
                 layer_norm_eps=1e-05, batch_first=False, norm_first=False): ...

    def forward(self, src, src_mask=None, src_key_padding_mask=None) -> Tensor: ...


class TransformerDecoder(Module):
    """Stack of ``num_layers`` copies of ``decoder_layer``, optionally followed by ``norm``."""

    def __init__(self, decoder_layer, num_layers, norm=None): ...

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None) -> Tensor: ...


class TransformerDecoderLayer(Module):
    """Single decoder layer: self-attention, cross-attention over ``memory``, then feed-forward."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu',
                 layer_norm_eps=1e-05, batch_first=False, norm_first=False): ...

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None) -> Tensor: ...


class MultiheadAttention(Module):
    """Multi-head attention; ``embed_dim`` must be divisible by ``num_heads``.

    Returns ``(attn_output, attn_weights)``. ``attn_weights`` is None when
    ``need_weights=False``; otherwise it is averaged over heads unless
    ``average_attn_weights=False``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False): ...

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None,
                average_attn_weights=True) -> Tuple[Tensor, Tensor]: ...

```

421

422

### Container Classes

423

424

Containers for organizing and combining multiple modules or parameters so that they are registered with the parent module.

425

426

```python { .api }

class Sequential(Module):
    """Chain of modules: ``forward`` passes the input through each module in order.

    Accepts modules positionally, or a single ``OrderedDict`` of named modules.
    """

    def __init__(self, *args): ...

    def forward(self, input): ...

    def append(self, module):
        """Append ``module`` to the end of the chain; returns ``self``."""


class ModuleList(Module):
    """List of sub-modules registered with the parent module.

    Unlike a plain Python list, its modules appear in ``parameters()`` and
    ``state_dict()`` and are moved by ``to()``. It has no ``forward``;
    iterate over it explicitly.
    """

    def __init__(self, modules=None): ...

    def append(self, module): ...

    def extend(self, modules): ...

    def insert(self, index, module): ...

    def __getitem__(self, idx): ...

    def __len__(self): ...

    def __iter__(self): ...


class ModuleDict(Module):
    """Insertion-ordered dictionary of sub-modules registered with the parent module."""

    def __init__(self, modules=None): ...

    def __getitem__(self, key): ...

    def __setitem__(self, key, module): ...

    def keys(self): ...

    def items(self): ...

    def values(self): ...


class ParameterList(Module):
    """List of Parameters registered with the parent module."""

    def __init__(self, parameters=None): ...

    def append(self, parameter): ...

    def extend(self, parameters): ...

    def __getitem__(self, idx): ...

    def __len__(self): ...


class ParameterDict(Module):
    """Insertion-ordered dictionary of Parameters registered with the parent module."""

    def __init__(self, parameters=None): ...

    def __getitem__(self, key): ...

    def __setitem__(self, key, parameter): ...

    def keys(self): ...

    def items(self): ...

    def values(self): ...

```

460

461

### Dropout and Regularization

462

463

Dropout-based regularization layers that reduce overfitting. They are active only in training mode and act as the identity after `model.eval()`.

464

465

```python { .api }

class Dropout(Module):
    """During training, zeroes each element with probability ``p`` and scales the rest by 1 / (1 - p).

    In eval mode it is the identity.
    """

    def __init__(self, p=0.5, inplace=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Dropout1d(Module):
    """Channel-wise dropout: zeroes entire channels of (N, C, L) or (C, L) inputs during training."""

    def __init__(self, p=0.5, inplace=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Dropout2d(Module):
    """Channel-wise dropout: zeroes entire channels of (N, C, H, W) or (C, H, W) inputs during training."""

    def __init__(self, p=0.5, inplace=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class Dropout3d(Module):
    """Channel-wise dropout: zeroes entire channels of (N, C, D, H, W) or (C, D, H, W) inputs during training."""

    def __init__(self, p=0.5, inplace=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class AlphaDropout(Module):
    """Dropout variant for SELU networks that preserves zero mean and unit variance of the input."""

    def __init__(self, p=0.5, inplace=False): ...

    def forward(self, input: Tensor) -> Tensor: ...

```

491

492

### Embedding Layers

493

494

Embedding layers that map discrete indices (such as word or token IDs) to dense vectors.

495

496

```python { .api }

class Embedding(Module):
    """Lookup table mapping integer indices to dense vectors.

    ``forward`` takes an integer tensor of indices of any shape ``*`` and
    returns a tensor of shape ``(*, embedding_dim)``. The row at
    ``padding_idx`` receives no gradient. If ``max_norm`` is set, looked-up
    rows whose ``norm_type``-norm exceeds it are renormalized.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=None, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, sparse=False): ...

    def forward(self, input: Tensor) -> Tensor: ...


class EmbeddingBag(Module):
    """Reduces bags of embeddings ('sum', 'mean' or 'max', per ``mode``) without materializing them.

    For 1D ``input`` the bags are delimited by ``offsets``; for 2D input each
    row is one fixed-length bag. ``per_sample_weights`` is only supported
    with ``mode='sum'``. Entries equal to ``padding_idx`` are excluded from
    the reduction.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, max_norm=None, norm_type=2.0,
                 scale_grad_by_freq=False, mode='mean', sparse=False, include_last_offset=False,
                 padding_idx=None): ...

    def forward(self, input: Tensor, offsets=None, per_sample_weights=None) -> Tensor: ...

```

509

510

## Usage Examples

511

512

### Simple Neural Network

513

514

```python

import torch
import torch.nn as nn
import torch.optim as optim


class SimpleNet(nn.Module):
    """Three-layer MLP producing class logits: input -> hidden -> hidden -> output."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.layers(x)


# Initialize model, loss and optimizer
model = SimpleNet(784, 128, 10)
criterion = nn.CrossEntropyLoss()  # expects raw logits, so the model has no final softmax
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example forward pass
x = torch.randn(32, 784)  # Batch of 32 samples
y = torch.randint(0, 10, (32,))  # Integer class labels in [0, 10)

output = model(x)  # (32, 10) logits
loss = criterion(output, y)

# Backward pass: clear stale gradients, backpropagate, update parameters
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

```

552

553

### Convolutional Neural Network

554

555

```python

import torch
import torch.nn as nn


class CNN(nn.Module):
    """Small CNN for 3-channel images: three conv blocks, global average pooling, linear head."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Pools to 1x1 whatever the input resolution, so the head always sees 128 features.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 128, 1, 1) -> (N, 128)
        return self.classifier(x)


# Initialize model
model = CNN(num_classes=10)

# Example inference pass: eval() disables dropout so the output is
# deterministic; no_grad() skips autograd bookkeeping.
model.eval()
x = torch.randn(8, 3, 32, 32)  # Batch of images
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")

```

592

593

### LSTM for Sequence Processing

594

595

```python

import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    """Sequence classifier: embedding -> stacked LSTM -> linear head on the last time step."""

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x: (batch, seq_len) token indices
        embedded = self.embedding(x)  # (batch, seq_len, embed_size)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden_size); final (h_n, c_n) unused
        # Use the output at the last position. NOTE: with right-padded batches this
        # position may be padding; pack the sequences or gather each true last step instead.
        last_step = lstm_out[:, -1, :]
        return self.fc(self.dropout(last_step))


# Initialize model
model = LSTMModel(vocab_size=10000, embed_size=128, hidden_size=256, num_layers=2, num_classes=5)

# Example inference pass: eval() disables dropout so the output is deterministic.
model.eval()
x = torch.randint(0, 10000, (16, 50))  # Batch of sequences
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")

```