## Gradient Checkpointing

A simple example of fine-tuning a model on a sentiment-analysis task with gradient checkpointing enabled.

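Gradient checkpointing reduces GPU memory (VRAM) use during training by keeping only a subset of the intermediate activations from the forward pass. The activations that were not kept are recomputed during the backward pass, so the memory savings cost extra compute time.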
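The trade-off can be illustrated without any deep-learning library. The sketch below is a toy under stated assumptions, not how `t2t` implements checkpointing: it models a network as a chain of scalar layers `sin(x)` (the names `layer`, `forward_checkpointed`, `backward_checkpointed` and the `every` parameter are all hypothetical). It keeps only every `every`-th layer input during the forward pass and recomputes each segment from its checkpoint during the backward pass, which yields the same gradient as plain backpropagation while holding far fewer activations at once.

```python
import math

# Toy stand-in for a network: a chain of n identical scalar layers y = sin(x).
# (Hypothetical; real layers operate on tensors, but the bookkeeping is the same.)
def layer(x):
    return math.sin(x)

def layer_grad(x):
    # Derivative of the layer w.r.t. its input, evaluated at that input.
    return math.cos(x)

def forward_full(x0, n):
    """Plain training: keep the input of every layer for the backward pass."""
    acts = [x0]
    for _ in range(n):
        acts.append(layer(acts[-1]))
    return acts  # acts[i] is the input to layer i; acts[n] is the output

def backward_full(acts, n):
    grad = 1.0  # d(output)/d(output)
    for i in reversed(range(n)):
        grad *= layer_grad(acts[i])  # chain rule, last layer first
    return grad  # d(output)/d(x0)

def forward_checkpointed(x0, n, every):
    """Keep only every `every`-th layer input (the checkpoints)."""
    checkpoints = {0: x0}
    x = x0
    for i in range(n):
        x = layer(x)
        if (i + 1) % every == 0 and i + 1 < n:
            checkpoints[i + 1] = x  # input to layer i + 1
    return x, checkpoints

def backward_checkpointed(checkpoints, n):
    """Walk segments from last to first, recomputing each segment's inputs."""
    grad, recomputed, end = 1.0, 0, n
    for start in sorted(checkpoints, reverse=True):
        acts = [checkpoints[start]]  # input to layer `start`
        for _ in range(start, end - 1):  # recompute inputs to layers start+1 .. end-1
            acts.append(layer(acts[-1]))
            recomputed += 1
        for a in reversed(acts):  # backprop through this segment only
            grad *= layer_grad(a)
        end = start  # this segment's activations can now be freed
    return grad, recomputed

n, every = 16, 4
acts = forward_full(0.5, n)
out, ckpts = forward_checkpointed(0.5, n, every)
g_ckpt, recomputed = backward_checkpointed(ckpts, n)
print(len(acts), "activations stored without checkpointing")
print(len(ckpts), "checkpoints stored; at most", every, "segment activations live during backward")
print("gradients match:", math.isclose(backward_full(acts, n), g_ckpt))
print(recomputed, "extra layer evaluations")
```

Choosing `every` near the square root of the number of layers keeps peak activation memory growing roughly with the square root of depth rather than linearly, at the price of about one extra forward pass.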

```python
import t2t
```

```python
trainer_arguments = t2t.TrainerArguments(
    # model
    model_name_or_path="t5-large",
    cache_dir="/workspace/cache",
    # data inputs
    train_file="../sample_data/trainlines.json",
    max_source_length=128,
    max_target_length=8,
    # training outputs
    output_dir="/tmp/saved_model",
    overwrite_output_dir=True,
    # training settings
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,  # 1 x 32 = 32 examples per optimizer step
    learning_rate=1e-5,
    gradient_checkpointing=True,  # recompute activations in the backward pass to save VRAM
    prefix="predict sentiment: ",  # task prefix for the T5 inputs
    # validation settings: none, since validation is skipped (valid=False) below
)
trainer = t2t.Trainer(arguments=trainer_arguments)
```
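With `per_device_train_batch_size=1` and `gradient_accumulation_steps=32`, gradients from 32 single-example micro-batches are accumulated before each optimizer update, so each update sees 32 examples per device (assuming these arguments carry their usual Hugging Face `TrainingArguments` meaning, on a single device). The pure-Python sketch below uses a hypothetical 1-D linear model with mean-squared-error loss and synthetic data to check that scaling each micro-batch gradient by `1 / accumulation_steps` and summing reproduces the full-batch gradient.

```python
import random

def grad_mse(w, batch):
    """d/dw of the mean squared error of the 1-D linear model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_grad(w, examples, micro_batch_size, accumulation_steps):
    """Sum micro-batch gradients, each scaled by 1 / accumulation_steps."""
    total = 0.0
    for step in range(accumulation_steps):
        micro = examples[step * micro_batch_size:(step + 1) * micro_batch_size]
        total += grad_mse(w, micro) / accumulation_steps
    return total  # one optimizer update would use this value

per_device_train_batch_size = 1
gradient_accumulation_steps = 32
num_devices = 1  # assumption: single GPU
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices

random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(effective_batch_size)]
examples = [(x, 3.0 * x + random.gauss(0, 0.1)) for x in xs]  # synthetic data

w = 0.5
full = grad_mse(w, examples)
accum = accumulated_grad(w, examples, per_device_train_batch_size, gradient_accumulation_steps)
print("effective batch size:", effective_batch_size)
print("full-batch gradient:", full)
print("accumulated gradient:", accum)
```

Accumulation keeps only one micro-batch's activations in memory at a time, so it lowers memory in a different way from checkpointing, and the two are often combined.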

```python
trainer.model_summary()
```

### Train Model

```python
trainer.train(valid=False)
```

### Test Model

```python
# The input starts with the same prefix passed as `prefix` during training.
input_text = "predict sentiment: This is the worst movie I have ever seen!"
trainer.generate_single(input_text, max_length=8)
```

```python
input_text = "predict sentiment: This is the best movie I have ever seen!"
trainer.generate_single(input_text, max_length=8)
```