chunked_lce.py
import torch
import triton
import triton.language as tl
from typing import List
import torch.nn as nn

torch.set_default_device('cuda')
@triton.jit
def liger_cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    n_cols,
    n_non_ignore,
    ignore_index,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes both the cross entropy loss and the gradient of the input.
    We only consider hard labels + mean reduction for now. Please refer to
    https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    Y_ptr: Pointer to the target tensor.
    Y_stride (int): The stride of the target tensor.
    loss_ptr: Pointer to the tensor in which to store the loss.
    loss_stride (int): The stride of the loss tensor.
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (int): The number of non-ignored elements in the batch.
    ignore_index (int): The index to ignore in the target.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
    program_id = tl.program_id(0).to(tl.int64)

    # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # 2. Locate the start index
    X_ptr += program_id * X_stride

    if y == ignore_index:
        # set all of X_ptr to 0
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return

    loss_ptr += program_id * loss_stride

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the running max; we use the notation from the paper
    d = 0.0  # d is the running sum of exponentials; we use the notation from the paper
    ori_X_y = tl.load(
        X_ptr + y
    )  # we need to store the original value of X_y for the loss calculation
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        block_max = tl.max(X_block)
        m_new = tl.maximum(m, block_max)
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    # 4. [Online softmax] second pass: calculate the gradients
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # N is the number of non-ignored elements in the batch
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written, as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. Calculate the loss
    # loss = -log(softmax(X_y)) = -log(e^(X_y - max(X)) / sum(e^(X - max(X))))
    #      = -((X_y - max(X)) - log(sum(e^(X - max(X)))))
    # sum(e^(X - max(X))) must be >= 1 because the max term is e^0 = 1,
    # so we can safely calculate log(softmax(X_y)) without overflow
    loss = -(ori_X_y - m - tl.log(d))

    # 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - 1) / N`
    X_y = tl.load(X_ptr + y)
    X_y += -1 / (n_non_ignore)

    tl.store(loss_ptr, loss)
    tl.store(X_ptr + y, X_y)
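# A minimal PyTorch sketch (an illustration, not part of the original gist) of what a single
# program of the kernel above computes for one row of logits `x` with target index `y`:
# the block-wise online-softmax statistics collapse to a plain max/sum here, and
# `n_non_ignore` plays the role of N in the gradient formulas from the comments.
def _reference_ce_row(x: torch.Tensor, y: int, n_non_ignore: int):
    m = x.max()                                  # pass 1: max
    d = torch.exp(x - m).sum()                   # pass 1: sum of shifted exponentials
    loss = -(x[y] - m - torch.log(d))            # loss from the original value of x[y]
    grad = torch.exp(x - m) / d / n_non_ignore   # pass 2: dx_i = softmax(x_i) / N
    grad[y] -= 1.0 / n_non_ignore                # i == y: dx_y = (softmax(x_y) - 1) / N
    return loss, grad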
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536 as in the LayerNorm tutorial is faster because of less register spilling.
# The optimal maximum block size depends on your hardware, your kernel, and your dtype.
MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manual tuning
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed to by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
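# For intuition (an equivalence I am assuming, not a claim from the gist): on a contiguous
# tensor the kernel above has the same effect as an in-place scale by the scalar grad_output.
# A Triton kernel is used instead because, as noted in the backward pass below, mutating
# saved tensors in place for gradient storage trips PyTorch's autograd checks.
def _reference_element_mul(X: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    return X.mul_(grad_output)  # grad_output is a 0-dim (scalar) tensor here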
def fused_linear_cross_entropy_forward(
    _input, weight, target, bias=None, ignore_index=-100
):
    dtype = (
        torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
    )
    device = _input.device

    # Inputs have shape: BT x H.
    # Materialized activations will have shape: BT x V.
    # The increase in memory = BT x V.
    # Reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
    # For example, to achieve the same memory consumption as BT x H, the chunk size should be:
    # inc_factor = (V + H - 1) // H, chunk_size = (BT + inc_factor - 1) // inc_factor
    # For example: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
    BT, H = _input.shape
    V = weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    chunk_size = triton.next_power_of_2(
        triton.cdiv(BT, inc_factor)
    )  # (BT + inc_factor - 1) // inc_factor
    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

    grad_weight = torch.zeros_like(weight, device=device)
    grad_input = torch.zeros_like(_input, device=device)
    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
    # we use fp32 for the loss accumulator
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # count only the targets that are not equal to ignore_index
    total_n_non_ignore = (target != ignore_index).sum().item()

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

        # when doing the matmul, use the original precision
        logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
        if bias is not None:
            logits_chunk = logits_chunk + bias
        target_chunk = target[start_idx:end_idx]  # chunk_size,

        n_rows = logits_chunk.shape[0]

        # unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
        n_non_ignore = (target_chunk != ignore_index).sum().item()

        # when doing CE, use the upcasted precision
        logits_chunk = logits_chunk.float()

        # ensure the logits and target are contiguous
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Here we calculate the gradient of logits_chunk in place so we can save memory.
        liger_cross_entropy_kernel[(n_rows,)](
            X_ptr=logits_chunk,
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
            loss_ptr=loss_1d_slice,
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=n_non_ignore,
            ignore_index=ignore_index,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32,
        )

        # The gradient of logits_chunk is computed in place by the Triton kernel above.
        # Following the HuggingFace model source code, we do the forward and backward
        # w.r.t. the logits in fp32 for numerical stability, especially as the number of classes (vocab size) is huge.
        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
        # Propagating to lm_head's backward, we switch back to the original dtype.
        logits_chunk = logits_chunk.to(dtype)

        # The in-place gradient of logits_chunk has shape chunk_size x V,
        # so grad_input[start_idx:end_idx] has shape chunk_size x H.
        # Additionally, since we are chunking the inputs, the kernel normalizes each chunk's
        # loss and gradients by that chunk's `n_non_ignore` tokens, while the mean should be
        # taken over all non-ignored tokens in the batch. Thus, we need an additional scaling
        # factor of (n_non_ignore / total_n_non_ignore) to scale the gradients.
        grad_logits_chunk = logits_chunk * (
            n_non_ignore / total_n_non_ignore
        )  # chunk_size x V
        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

        torch.addmm(
            input=grad_weight,
            mat1=logits_chunk.t(),
            mat2=_input_chunk,
            out=grad_weight,
            alpha=n_non_ignore / total_n_non_ignore,
            beta=1.0,
        )

        if bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.sum(dim=0),
                out=grad_bias,
                alpha=n_non_ignore / total_n_non_ignore,
            )

    loss = torch.sum(loss_1d) / total_n_non_ignore
    return loss, grad_input, grad_weight, grad_bias
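# A worked example of the chunking arithmetic from the comments above (the numbers are the
# ones quoted there; the helper itself is illustrative and not part of the original gist):
def _chunking_example():
    BT, V, H = 4096 * 4, 32000, 4096
    inc_factor = triton.cdiv(V, H)                                    # ceil(32000/4096) = 8
    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # next_pow2(2048) = 2048
    num_chunks = triton.cdiv(BT, chunk_size)                          # ceil(16384/2048) = 8
    assert (inc_factor, chunk_size, num_chunks) == (8, 2048, 8)
    # peak materialized logits are chunk_size x V = 2048 x 32000 floats,
    # i.e. roughly the same footprint as the BT x H = 16384 x 4096 input
    return inc_factor, chunk_size, num_chunks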
def fused_linear_cross_entropy_backward(
    grad_output, grad_input, grad_weight, grad_bias
):
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time.
    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in place
        # for gradient storage and calling backward multiple times causes anomalies with PyTorch, but not with Triton.
        BT, H = grad_input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        element_mul_kernel[(n_rows,)](
            grad_input,
            grad_input.stride(-2),
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32,
        )

        # handle grad_weight
        V, H = grad_weight.shape
        n_rows = V

        element_mul_kernel[(n_rows,)](
            grad_weight,
            grad_weight.stride(-2),
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32,
        )

        if grad_bias is not None:
            V = grad_bias.shape[0]
            n_rows = V

            element_mul_kernel[(n_rows,)](
                grad_bias,
                grad_bias.stride(-1),
                grad_output,
                1,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32,
            )
    return grad_input, grad_weight, grad_bias
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
        """
        Fuses the last linear layer with the cross-entropy loss.
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        Handles the forward and backward pass of the final linear layer with cross-entropy loss, avoiding
        the materialization of the large logits tensor. Since cross-entropy loss is the last layer, we can
        compute the gradient in the forward pass. By doing so, we don't have to store _input and target
        for the backward pass.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
        ignore_index: the index to ignore in the target
        """
        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input, weight, target, bias, ignore_index
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach(),
            grad_bias.detach() if bias is not None else None,
        )
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
        )
        return (grad_input, grad_weight, None, grad_bias, None)
class ChunkedCE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None, compiled=True):
        CHUNK_SIZE = 1024

        def compute_loss(input_chunk, weight, bias, target):
            # uses the module-level `ce = nn.CrossEntropyLoss()` defined below
            logits = torch.addmm(bias, input_chunk, weight.t())
            logits = logits.float()
            loss = ce(logits, target)
            return loss

        grad_weight = torch.zeros_like(weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(bias)
        loss_acc = torch.zeros((), device=_input.device)

        # assumes _input.shape[0] is a multiple of CHUNK_SIZE, so every chunk has equal size
        # and averaging the per-chunk mean losses equals the global mean
        chunks = max(_input.shape[0] // CHUNK_SIZE, 1)

        def accumulate_chunk(input_chunk, target_chunk):
            # grad_and_value gives the per-chunk gradients and loss without keeping
            # an autograd graph alive across chunks
            (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(
                compute_loss, argnums=(0, 1, 2)
            )(input_chunk, weight, bias, target_chunk)
            grad_weight.add_(chunk_grad_weight)
            grad_bias.add_(chunk_grad_bias)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        target_chunks = torch.chunk(target, chunks=chunks, dim=0)
        for input_chunk, target_chunk in zip(input_chunks, target_chunks):
            grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0) / chunks,
            grad_weight / chunks,
            grad_bias / chunks,
        )
        return loss_acc / chunks

    @staticmethod
    def backward(ctx, grad_output):
        # assumes grad_output == 1.0 (the loss is the final node), matching the liger fast path
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        return (grad_input, grad_weight, None, grad_bias, None)
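# A toy illustration (not part of the original gist) of the torch.func.grad_and_value
# pattern ChunkedCE relies on: it differentiates a scalar-valued function with respect to
# the chosen positional arguments and returns the gradients together with the loss value.
def _grad_and_value_demo():
    def sq_loss(a, w):
        return ((a @ w) ** 2).mean()
    a, w = torch.randn(4, 3), torch.randn(3, 2)
    (da, dw), val = torch.func.grad_and_value(sq_loss, argnums=(0, 1))(a, w)
    assert da.shape == a.shape and dw.shape == w.shape
    return val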
B, T, D, V = 32, 1024, 768, 128256
model = nn.Linear(D, V).to(torch.bfloat16)
ce = nn.CrossEntropyLoss()
x = torch.randn(B, T, D, requires_grad=True, dtype=torch.bfloat16)
label = torch.randint(0, V, (B, T)).to(torch.int64)

def f(m, x, label):
    out = ce(m(x).view(-1, V), label.view(-1))
    out.backward()
    return out

def chunked_f(m, x, label, compiled=False):
    out = ChunkedCE.apply(x.view(-1, D), m.weight, label.view(-1), m.bias, compiled)
    out.backward()
    return out

def ligerf(m, x, label):
    out = LigerFusedLinearCrossEntropyFunction.apply(x.view(-1, D), m.weight, label.view(-1), m.bias)
    out.backward()
    return out
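# Sanity check (not in the original gist): all three implementations should produce the same
# loss on the same inputs, up to bf16/fp32 accumulation noise; the tolerance is a guess.
# Gradients accumulate into .grad across calls, so only the losses are compared here.
def check_losses(atol=1e-2):
    ref = f(model, x, label).item()
    chunked = chunked_f(model, x, label, compiled=False).item()
    liger = ligerf(model, x, label).item()
    assert abs(ref - chunked) < atol and abs(ref - liger) < atol, (ref, chunked, liger)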
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False, profile_mem=False):
    from triton.testing import do_bench
    for _ in range(warmup):
        f()

    if profile_mem:
        torch.cuda.memory._record_memory_history()
        f()
        torch.cuda.memory._dump_snapshot(f"{name if name is not None else 'memory'}.pickle")
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    torch.cuda.reset_peak_memory_stats()
    ms_per_iter = do_bench(lambda: f())
    if name is None:
        res = ms_per_iter
    else:
        res = f"{name}: {ms_per_iter:.3f}ms"
    if display:
        print(res)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / 1e9)
        print()
    return res
opt_f = torch.compile(f)
bench(lambda: ligerf(model, x, label), name='liger lce')
bench(lambda: f(model, x, label), name='eager (non-chunked)')
bench(lambda: chunked_f(model, x, label, compiled=False), name='eager (chunked)')
bench(lambda: opt_f(model, x, label), name='compile (non-chunked)')
bench(lambda: chunked_f(model, x, label, compiled=True), name='compile (chunked)')
There was a bug at L190 in computing the total number of non-ignored indices: `total_n_non_ignore` was computed as `target.numel()`, which also counts ignored targets, instead of counting only targets not equal to `ignore_index`. This has been fixed upstream by: