TempBalance-LM

Language modeling experiments for the paper: Temperature Balancing, Layer-wise Weight Analysis, and Neural Network Training [NeurIPS 2023 Spotlight]

Yefan Zhou, Tianyu Pang, Keqin Liu, Charles H. Martin, Michael W. Mahoney, Yaoqing Yang

Full paper

Install

# set up the conda environment
bash install.sh
conda activate ww_train_lm
# prepare the Penn Treebank data
bash penn_tree.sh

Usage

from tempbalance import Tempbalance
import torch
import torch.optim as optim

model = ...
# initialize the necessary hyperparameters
start_lr = ...
total_steps = ...
# initialize the TempBalance scheduler
tb_scheduler = Tempbalance(
    net=model,
    start_lr=start_lr,
    total_steps=total_steps,
    lr_min_ratio=0.5,
    lr_max_ratio=1.5,
)
# build the layer-wise optimizer parameter groups
tb_param_group = tb_scheduler.build_optimizer_param_group(untuned_lr=0.1)
optimizer = optim.SGD(
    tb_param_group,
    ...
)
# training loop
for epoch in range(1, ...):
    ...
    train()
    test()
    # get the globally decayed learning rate for this epoch
    untuned_global_lr = some_torch_lr_scheduler(epoch)
    # temperature balancing: rescale each layer's learning rate
    tb_scheduler.step(optimizer, untuned_global_lr, current_step)
    ...
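In the call above, lr_min_ratio and lr_max_ratio bound how far each layer's learning rate may deviate from the untuned global rate (here, between 0.5x and 1.5x of it). The sketch below is a conceptual illustration of such a bounded mapping, based on our reading of the paper rather than the library's internals: a layer-wise spectral metric is mapped linearly into that range.

# Conceptual sketch only, NOT Tempbalance's actual implementation:
# linearly map a layer's spectral metric into
# [lr_min_ratio, lr_max_ratio] times the global learning rate.
def scaled_layer_lr(metric, metric_min, metric_max, global_lr,
                    lr_min_ratio=0.5, lr_max_ratio=1.5):
    t = (metric - metric_min) / max(metric_max - metric_min, 1e-12)
    ratio = lr_min_ratio + t * (lr_max_ratio - lr_min_ratio)
    return ratio * global_lr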
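Note that some_torch_lr_scheduler above is a placeholder: tb_scheduler.step takes the decayed global learning rate as a plain value, so any schedule that yields a scalar per epoch works. A minimal sketch, assuming a cosine-annealed schedule (the helper name is ours, not part of this repo):

import math

# Hypothetical helper (not provided by this repo): cosine-annealed
# global learning rate for the given epoch.
def cosine_global_lr(epoch, total_epochs, start_lr, min_lr=0.0):
    progress = epoch / total_epochs
    return min_lr + 0.5 * (start_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# usage inside the training loop above:
# untuned_global_lr = cosine_global_lr(epoch, total_epochs, start_lr)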

Experiments

# Baseline 
bash ./BTD-Transformer/scripts/tensorized/run_ptb.sh

# TempBalance
bash ./BTD-Transformer/scripts/tensorized/run_ptb_tb.sh

Acknowledgement

  1. We thank the authors of the open-source codebase The-compression-of-Transformer.
