    Train a Model Faster with torch.compile and Gradient Accumulation

    By Oliver Chambers | December 26, 2025 | 6 min read


    Training a language model with a deep transformer architecture is time-consuming. However, there are techniques you can use to speed up training. In this article, you will learn:

    • Using torch.compile() to speed up the model
    • Using gradient accumulation to train a model with a larger effective batch size

    Let's get started!

    Photo by François Genon. Some rights reserved.

    Overview

    This article is divided into two parts; they are:

    • Using torch.compile()
    • Gradient Accumulation

    Using torch.compile

    When you write your model code and run it with PyTorch, the code is executed in eager mode. This means the code is executed line by line, and the results are stored in memory. This is natural for Python since it is an interpreted language: if you make a mistake in your code, you will not see the error until you run that line of code.

    Running a model in eager mode is slow. Starting with PyTorch 2.0, you can use torch.compile() to compile a model for improved performance. This generates a new, optimized model object. It is not the same model object you created using nn.Module, but it shares the same tensors with the original model. You can use this compiled model for the forward pass, backward pass, and optimizer updates as usual.
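
    As a minimal illustration (using a toy nn.Sequential model and random tensors purely for demonstration, not the Llama model from this series), the compiled wrapper drops into an ordinary training step unchanged:

    import torch
    import torch.nn as nn

    # a small stand-in model; your own nn.Module works the same way
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    compiled_model = torch.compile(model)  # optimized wrapper sharing the same weight tensors

    x = torch.randn(32, 128)
    y = torch.randint(0, 10, (32,))

    # forward pass, backward pass, and optimizer update work exactly as before
    logits = compiled_model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    The first call to the compiled model triggers the actual compilation, so expect the first batch to be noticeably slower than the rest.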

    Building a model and compiling it as a computation graph is how TensorFlow 1.0 was designed to work. This makes debugging harder, since the model you execute cannot be matched line by line with the code you wrote. Therefore, you should not compile your model until you have run a trial and confirmed that it is error-free.

    Not all models can be compiled. However, if your model supports compilation, you immediately benefit from the speedup. To compile a model, all you need to do is replace the model object right before you are ready to use it:

    ...
    model = LlamaForPretraining(model_config).to(device)
    model.load_state_dict(checkpoint)
    model = torch.compile(model)
    ...

    Do not load the model weights after compilation. This is because the compiled model is an object that shares the same weights as the original model. During compilation, the computation graph is built referencing the weight tensors of the original model. If you load the weights after compilation, the model may not work as expected.

    Similarly, to save the compiled model, you should refer to the original model's state dict, as follows:

    torch.save(getattr(model, "_orig_mod", model).state_dict(), "model.pth")

    The original model can be accessed from the compiled model using model._orig_mod. In the code above, we use getattr(model, "_orig_mod", model) to get the original model if the attribute exists, or fall back to model itself if it does not. This line of code works for both compiled and original models.
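
    Because the checkpoint was saved from the original model's state dict, you can load it back into a fresh, uncompiled model and only compile afterwards. A minimal sketch, assuming the same LlamaForPretraining class, model_config, and device used earlier in this series:

    # recreate the model, restore the weights, then compile
    model = LlamaForPretraining(model_config).to(device)
    state_dict = torch.load("model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model = torch.compile(model)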

    Gradient Accumulation

    When you train a model, you typically spend two to three times more time on the backward pass than the forward pass. This is because the backward pass is more computationally intensive and uses more memory.

    One easy trick to speed up training is to perform fewer backward passes. This can be achieved by increasing the batch size: with the same number of data samples, a larger batch size means fewer batches to process.

    However, a larger batch size requires more memory. In a memory-constrained environment, you can mimic a larger batch size by running several forward passes and accumulating the gradients. This is called gradient accumulation.

    It is easier to explain this idea with code:

    ...
    accumulate_steps = 4
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        for i, batch in enumerate(dataloader):
            # get batched data
            input_ids, target_ids = batch
            # create attention mask: causal mask + padding mask
            attn_mask = create_causal_mask(input_ids.shape[1], device) + \
                        create_padding_mask(input_ids, PAD_TOKEN_ID, device)
            # extract output from model
            logits = model(input_ids, attn_mask)
            # compute loss: cross-entropy between logits and target, ignoring padding tokens
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss = loss / accumulate_steps
            # run backward, but update only once every `accumulate_steps` steps
            loss.backward()
            if (i + 1) % accumulate_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

    The training loop above is an excerpt from the earlier article on training a Llama model on your local GPU.

    Normally, when you run a forward pass, you compute the loss. You then call loss.backward() to backpropagate the loss gradient through the model parameters. In PyTorch, the backward() method is cumulative, meaning gradients are added up. Therefore, you must call optimizer.zero_grad() explicitly to clear the gradients before running the backward pass.

    In the code above, you deliberately do not call optimizer.zero_grad() in every iteration. Instead, you run backpropagation for the loss divided by accumulate_steps. This way, the gradients are scaled down but accumulated over accumulate_steps iterations. Once every accumulate_steps iterations, you run the optimizer to adjust the model parameters.
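
    To see why dividing the loss by accumulate_steps reproduces a larger batch, here is a small self-contained check (a toy linear model with synthetic data, not the training loop above) that compares accumulated gradients against a single full-batch backward pass:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    loss_fn = nn.MSELoss()
    X, y = torch.randn(8, 4), torch.randn(8, 1)

    # full batch: one forward/backward pass
    model.zero_grad()
    loss_fn(model(X), y).backward()
    full_grad = model.weight.grad.clone()

    # accumulation: two half-batches, each loss divided by the number of accumulation steps
    model.zero_grad()
    for chunk_x, chunk_y in zip(X.chunk(2), y.chunk(2)):
        (loss_fn(model(chunk_x), chunk_y) / 2).backward()

    print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # prints True

    The match is exact here because the two chunks are the same size and the loss uses mean reduction; with uneven chunk sizes the accumulated gradient would only approximate the full-batch one.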

    This approach yields results comparable to using a larger batch size. However, since you run fewer optimizer updates, the learning rate schedule needs to be adjusted accordingly. This means you should initialize the scheduler with a different number of steps:

    ...
    num_training_steps = (len(dataloader) // accumulate_steps) * num_epochs
    cosine_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_training_steps - num_warmup_steps,
        eta_min=0
    )
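
    The warmup portion is not shown in the snippet above. One common way to combine a linear warmup with the cosine schedule is torch.optim.lr_scheduler.SequentialLR; the sketch below assumes a linear warmup over num_warmup_steps optimizer updates, which may differ from the exact warmup used in the earlier article:

    from torch.optim import lr_scheduler

    warmup_scheduler = lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps
    )
    cosine_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0
    )
    scheduler = lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[num_warmup_steps]
    )

    Note that because scheduler.step() is called only once per optimizer update in the loop above, all of these step counts are measured in optimizer updates, not in mini-batches.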

    Further Reading

    Below are some materials that you may find interesting:

    Summary

    In this article, you learned that torch.compile() can help you speed up a model by compiling its computation graph. You also learned that gradient accumulation is a technique for training with a larger effective batch size by accumulating gradients from several mini-batches. Since you run fewer optimizer updates this way, you save time on parameter updates.
