Module
visionlab.VisionTransformer
Bases: LightningModule
A custom PyTorch Lightning LightningModule for torchvision Vision Transformer (ViT) models.
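Parameters:

Name | Type | Description | Default |
---|---|---|---|
optimizer | str | Name of a valid torch.optim optimizer class, e.g. "Adam" or "SGD". | 'Adam' |
lr | float | Learning rate passed to the optimizer. | 0.001 |
accuracy_task | str | Task type for the accuracy metric: one of binary, multiclass, multilabel. | 'multiclass' |
image_size | int | Height and width of the square input images, in pixels. | 32 |
num_classes | int | Number of output classes. | 100 |
dropout | float | Dropout rate. | 0.0 |
attention_dropout | float | Dropout rate applied to the attention weights. | 0.0 |
norm_layer | Optional[Module] | Normalization layer for the transformer encoder. | None |
conv_stem_configs | Optional[List[ConvStemConfig]] | Convolutional stem configuration used in place of the default patch projection. | None |
progress | bool | Whether to display a progress bar while downloading pretrained weights. | False |
weights | bool | Whether to load pretrained weights. | False |
vit_type | str | ViT variant: one of b_16, b_32, l_16, l_32. | 'b_32' |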
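
The vit_type suffix encodes the patch size (for example, b_32 is ViT-B/32 with 32×32 patches), and torchvision's VisionTransformer requires image_size to be divisible by that patch size. The sketch below assumes vit_type maps onto the torchvision builders vit_b_16, vit_b_32, vit_l_16 and vit_l_32; it checks a configuration and computes the resulting token sequence length (one token per patch plus a class token). The helper vit_sequence_length is hypothetical and not part of visionlab.

```python
# Patch size implied by each vit_type suffix (ViT-B/16, ViT-B/32, ViT-L/16, ViT-L/32).
PATCH_SIZES = {"b_16": 16, "b_32": 32, "l_16": 16, "l_32": 32}


def vit_sequence_length(vit_type: str = "b_32", image_size: int = 32) -> int:
    """Return the number of tokens a ViT processes: one per patch plus the class token.

    Hypothetical helper for illustration; not part of visionlab.
    """
    if vit_type not in PATCH_SIZES:
        raise ValueError(f"vit_type must be one of {sorted(PATCH_SIZES)}, got {vit_type!r}")
    patch_size = PATCH_SIZES[vit_type]
    # torchvision's VisionTransformer asserts image_size % patch_size == 0.
    if image_size % patch_size != 0:
        raise ValueError(f"image_size={image_size} is not divisible by patch size {patch_size}")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


print(vit_sequence_length("b_32", 32))   # 1 patch + class token -> 2
print(vit_sequence_length("b_16", 32))   # 2x2 patches + class token -> 5
print(vit_sequence_length("l_16", 224))  # 14x14 patches + class token -> 197
```

With the defaults (b_32, image_size=32), each image becomes a single patch, so the encoder attends over only two tokens; b_16 at the same resolution yields a 2×2 patch grid.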

Methods:

configure_optimizers(): Configures the torch.optim optimizer used in the training loop.
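
The optimizer parameter takes the name of a torch.optim class. One plausible way to turn that name into an optimizer inside configure_optimizers() is a getattr lookup; the sketch below assumes this approach and is not necessarily how visionlab implements it. The helper build_optimizer is hypothetical, and its namespace argument exists only so the lookup can be exercised without installing PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Iterable, Optional


def build_optimizer(params: Iterable[Any], name: str = "Adam", lr: float = 1e-3,
                    namespace: Optional[Any] = None) -> Any:
    """Instantiate the optimizer class called `name` (hypothetical helper)."""
    if namespace is None:
        # Default to the real torch.optim module, imported lazily.
        import torch.optim as namespace
    optimizer_cls = getattr(namespace, name, None)
    if optimizer_cls is None or not callable(optimizer_cls):
        raise ValueError(f"{name!r} is not a valid torch.optim name")
    return optimizer_cls(params, lr=lr)


# Stand-in namespace so the lookup can be checked without PyTorch.
fake_optim = SimpleNamespace(Adam=lambda params, lr: ("Adam", list(params), lr))
print(build_optimizer([1, 2], "Adam", lr=0.01, namespace=fake_optim))  # ('Adam', [1, 2], 0.01)
```

With PyTorch installed, build_optimizer(model.parameters(), "SGD", lr=0.1) would return a torch.optim.SGD instance, because torch.optim.SGD(params, lr=0.1) is a valid constructor call.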

forward(x): Calls .forward of the wrapped Vision Transformer model on the input x.

predict_step(batch): Returns the predicted logits from the trained model.
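
predict_step returns raw logits rather than labels. How those logits become predictions depends on accuracy_task; the conventions below (argmax for multiclass, a sigmoid with a 0.5 threshold for binary and multilabel) are common defaults assumed here, not taken from visionlab. The function decode_logits is hypothetical and works on plain Python lists for a single sample to stay dependency-free.

```python
import math
from typing import List, Union


def sigmoid(z: float) -> float:
    # Numerically stable logistic function: never exponentiates a large positive number.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_logits(logits: List[float], task: str = "multiclass",
                  threshold: float = 0.5) -> Union[int, List[int]]:
    """Turn one sample's logits into a prediction (hypothetical helper)."""
    if task == "multiclass":
        # Index of the largest logit; softmax is monotonic, so it is not needed here.
        return max(range(len(logits)), key=lambda i: logits[i])
    if task == "binary":
        # A single logit: positive class when its sigmoid exceeds the threshold.
        return int(sigmoid(logits[0]) > threshold)
    if task == "multilabel":
        # Each label is an independent yes/no decision.
        return [int(sigmoid(z) > threshold) for z in logits]
    raise ValueError(f"task must be binary, multiclass, or multilabel, got {task!r}")


print(decode_logits([0.1, 2.3, -1.0], "multiclass"))  # 1
print(decode_logits([-0.7], "binary"))                # 0
print(decode_logits([1.5, -2.0, 0.2], "multilabel"))  # [1, 0, 1]
```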

test_step(batch, *args): Runs a single test step on a batch.

training_step(batch): Runs a single training step on a batch.

validation_step(batch, *args): Runs a single validation step on a batch.
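
The training, validation, and test steps each run a forward pass on a batch and score the resulting logits. For a multiclass classifier the usual loss is cross-entropy; the pure-Python sketch below assumes that loss (which loss visionlab uses is not stated here) and shows what it computes for a batch, using the log-sum-exp trick for numerical stability. batch_cross_entropy is a hypothetical name.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating so large logits cannot overflow.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def batch_cross_entropy(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-likelihood of the target classes (hypothetical helper)."""
    if len(batch_logits) != len(targets):
        raise ValueError("one target per sample is required")
    losses = [-log_softmax(row)[t] for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # batch of 2 samples, 3 classes
targets = [0, 2]
print(round(batch_cross_entropy(logits, targets), 4))  # 0.67
```

In PyTorch, torch.nn.functional.cross_entropy(logits, targets) computes the same batch mean for a logits tensor of shape (N, C) and integer targets of shape (N,).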