
Module

visionlab.VisionTransformer

Bases: LightningModule

A custom PyTorch Lightning LightningModule for torchvision VisionTransformer models.

Parameters:

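| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | str | Name of the optimizer; must be a valid torch.optim class name, e.g. "Adam". | 'Adam' |
| lr | float | Learning rate. | 0.001 |
| accuracy_task | str | Task type for the accuracy metric: one of "binary", "multiclass", "multilabel". | 'multiclass' |
| image_size | int | Height and width of the square input images, in pixels. | 32 |
| num_classes | int | Number of target classes. | 100 |
| dropout | float | Dropout rate. | 0.0 |
| attention_dropout | float | Dropout rate applied to the attention weights. | 0.0 |
| norm_layer | Optional[Module] | Normalization layer. | None |
| conv_stem_configs | Optional[List[ConvStemConfig]] | Optional convolutional stem configuration (torchvision ConvStemConfig). | None |
| progress | bool | Whether to display a progress bar when downloading weights. | False |
| weights | bool | Whether to load pretrained weights. | False |
| vit_type | str | ViT variant: one of "b_16", "b_32", "l_16", "l_32". | 'b_32' |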
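
The allowed values in the table can be checked before any model is built. The sketch below makes one assumption: the patch size is read off the vit_type name, so "b_16" means 16-pixel patches. ViTConfig, VIT_PATCH_SIZES and ACCURACY_TASKS are hypothetical names for illustration and are not part of visionlab.

```python
from dataclasses import dataclass

# Patch size encoded in each documented vit_type name (e.g. "b_32" -> 32-pixel patches).
VIT_PATCH_SIZES = {"b_16": 16, "b_32": 32, "l_16": 16, "l_32": 32}
ACCURACY_TASKS = ("binary", "multiclass", "multilabel")


@dataclass
class ViTConfig:
    """Hypothetical container for a subset of the documented constructor arguments."""

    optimizer: str = "Adam"
    lr: float = 1e-3
    accuracy_task: str = "multiclass"
    image_size: int = 32
    num_classes: int = 100
    vit_type: str = "b_32"

    def __post_init__(self) -> None:
        # Reject values outside the documented choices.
        if self.accuracy_task not in ACCURACY_TASKS:
            raise ValueError(f"accuracy_task must be one of {ACCURACY_TASKS}, got {self.accuracy_task!r}")
        if self.vit_type not in VIT_PATCH_SIZES:
            raise ValueError(f"vit_type must be one of {tuple(VIT_PATCH_SIZES)}, got {self.vit_type!r}")
        # torchvision's VisionTransformer requires image_size to be divisible by the patch size.
        patch_size = VIT_PATCH_SIZES[self.vit_type]
        if self.image_size % patch_size != 0:
            raise ValueError(f"image_size={self.image_size} is not divisible by patch size {patch_size}")
        if self.lr <= 0:
            raise ValueError("lr must be positive")
```

Catching a bad combination here, such as vit_type="b_16" with image_size=40, produces a clear error message. Without the check, the problem would only surface inside the model's forward pass.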

configure_optimizers()

Configures the torch.optim optimizer used in the training loop.
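
Since the optimizer parameter is documented as a torch.optim name, one plausible approach is a getattr lookup on torch.optim. The sketch below assumes that approach; resolve_optimizer is a hypothetical helper, not visionlab code. It takes the namespace as an argument so it can run without PyTorch.

```python
import inspect
from typing import Any, Optional


def resolve_optimizer(name: str, namespace: Optional[Any] = None) -> type:
    """Look up an optimizer class by name, e.g. "Adam" -> torch.optim.Adam.

    Hypothetical helper: namespace defaults to torch.optim (requires PyTorch),
    but any object exposing optimizer classes as attributes can be passed in.
    """
    if namespace is None:
        import torch.optim as namespace  # imported lazily so the helper runs without torch
    cls = getattr(namespace, name, None)
    # Reject unknown names and non-class attributes such as submodules.
    if not inspect.isclass(cls):
        raise ValueError(f"{name!r} is not an optimizer class in {namespace!r}")
    return cls
```

With PyTorch installed, the result would be used as resolve_optimizer("Adam")(model.parameters(), lr=1e-3).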

forward(x)

Calls .forward on the wrapped VisionTransformer model.
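
To see what a forward pass does with the documented image_size and vit_type, it helps to trace tensor shapes. The sketch assumes the wrapped model follows torchvision's VisionTransformer: a convolutional patch embedding, a prepended class token, and a classification head on that token. It also assumes 3-channel input images. vit_shapes and HIDDEN_DIMS are hypothetical names.

```python
from typing import Dict, Tuple

# Hidden (embedding) sizes of the torchvision ViT variants: ViT-B uses 768, ViT-L uses 1024.
HIDDEN_DIMS = {"b": 768, "l": 1024}


def vit_shapes(batch_size: int, image_size: int = 32, vit_type: str = "b_32",
               num_classes: int = 100) -> Dict[str, Tuple[int, ...]]:
    """Trace tensor shapes through a torchvision-style ViT forward pass."""
    variant, patch = vit_type.split("_")
    patch_size = int(patch)
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by the patch size")
    # The conv patch embedding turns each non-overlapping patch into one token.
    num_patches = (image_size // patch_size) ** 2
    # A learnable class token is prepended; its final state feeds the classification head.
    seq_len = num_patches + 1
    return {
        "input": (batch_size, 3, image_size, image_size),
        "tokens": (batch_size, seq_len, HIDDEN_DIMS[variant]),
        "logits": (batch_size, num_classes),
    }
```

With the defaults (image_size=32, vit_type="b_32"), the whole image forms a single patch, so the encoder sees only two tokens. With vit_type="b_16", the same image yields four patches and five tokens.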

predict_step(batch)

Returns the predicted logits from the trained model.
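
Because predict_step returns raw logits, turning them into class labels or probabilities is left to the caller. The sketch below assumes the multiclass case and works on plain Python lists; softmax and predict_labels are illustrative helpers.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert one row of logits into class probabilities."""
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(batch_logits: Sequence[Sequence[float]]) -> List[int]:
    """Pick the highest-scoring class index for each row (multiclass decoding)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]
```

On a PyTorch logits tensor, the equivalents are logits.argmax(dim=1) and logits.softmax(dim=1). For a multilabel task, an independent sigmoid per class with a threshold would replace the softmax/argmax step.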

test_step(batch, *args)

Runs a single test step on a batch.

training_step(batch)

Runs a single training step on a batch.
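
The documentation does not say which loss training_step optimizes. For a classifier that outputs logits, a common choice is mean cross-entropy against integer labels, as computed by torch.nn.functional.cross_entropy with its default reduction. The sketch below assumes that loss; cross_entropy here is an illustrative plain-Python version, not the module's code.

```python
import math
from typing import Sequence


def cross_entropy(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean cross-entropy between raw logits and integer class targets."""
    if len(batch_logits) != len(targets) or not targets:
        raise ValueError("need one target per row and a non-empty batch")
    total = 0.0
    for logits, target in zip(batch_logits, targets):
        m = max(logits)
        # log-sum-exp with the max subtracted for numerical stability
        log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
        # -log softmax(logits)[target]
        total += log_norm - logits[target]
    return total / len(targets)
```

Uniform logits over k classes give a loss of log(k), which is a useful sanity check for an untrained model. With the default num_classes=100, that is about 4.6.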

validation_step(batch, *args)

Runs a single validation step on a batch.
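
Given the accuracy_task parameter, the validation and test steps presumably report an accuracy metric. The sketch below covers only the multiclass case: the fraction of samples whose top-scoring class matches the target (micro averaging). It is an assumption for illustration, and the module's actual metric and averaging may differ. multiclass_accuracy is a hypothetical helper.

```python
from typing import Sequence


def multiclass_accuracy(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the target (micro-averaged)."""
    if len(batch_logits) != len(targets) or not targets:
        raise ValueError("need one target per row and a non-empty batch")
    correct = sum(
        1
        for row, target in zip(batch_logits, targets)
        if max(range(len(row)), key=row.__getitem__) == target
    )
    return correct / len(targets)
```

For example, logits [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]] with targets [1, 0, 0] give 2 correct out of 3.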