Remove turbo (#96)
* rm turbo usage + associated tests

* rip out rest of turbo + all tests passing except topology_test
dblalock committed Feb 20, 2024
1 parent f05609c commit 014667c
Showing 12 changed files with 87 additions and 377 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -20,8 +20,6 @@ NOTE: This assumes you have `numpy` and `torch` installed.

**Extras:** MegaBlocks has optional dependencies that enable additional features.

Installing `megablocks[quant]` enables configurable quantization of saved activations in the dMoE layer to save memory during training. The degree of quantization is controlled via the `quantize_inputs_num_bits`, `quantize_rematerialize_num_bits` and `quantize_scatter_num_bits` [arguments](https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/arguments.py).

Installing `megablocks[gg]` enables dMoE computation with grouped GEMM. This feature is enabled by setting the `mlp_impl` argument to `grouped`. This is currently our recommended path for Hopper-generation GPUs.

MegaBlocks can be installed with all dependencies via the `megablocks[all]` package.
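The surviving extra configures the grouped-GEMM path by setting `mlp_impl` to `'grouped'`. As a rough illustration only (not code from this commit), a dMoE layer built with that setting might look like the sketch below; it assumes `megablocks[gg]` is installed and a CUDA device is available, and it reuses the `Arguments` fields and `out, _ = layer(x)` convention that appear in the test diff further down.

```python
# Sketch: enabling dMoE computation with grouped GEMM, as described in the
# README excerpt above. Hyperparameter values are arbitrary examples.
import torch
from megablocks.layers.arguments import Arguments
from megablocks.layers import dmoe

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=2048,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl='grouped',  # grouped-GEMM dMoE computation
    fp16=False,
    bf16=True,
)
layer = dmoe.dMoE(args).cuda()

x = torch.randn(512, 1, 1024).to(torch.bfloat16).cuda()  # (seq, batch, hidden)
out, _ = layer(x)  # second value is tokens_per_expert, as in the tests below
```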
15 changes: 0 additions & 15 deletions megablocks/layers/arguments.py
@@ -1,6 +1,5 @@
import dataclasses
from functools import partial
import megablocks.turbo_util as turbo
import megablocks.grouped_gemm_util as grouped_gemm
import torch
import torch.nn.functional as F
@@ -45,9 +44,6 @@ class Arguments:
memory_optimized_mlp : bool = False
mlp_type : str = 'mlp'
mlp_impl : str = 'sparse'
quantize_inputs_num_bits: int = -1 # -1 = no quantization
quantize_rematerialize_num_bits: int = -1
quantize_scatter_num_bits: int = -1

# Initialization arguments.
fp16 : bool = True
@@ -60,17 +56,6 @@ class Arguments:
uniform_expert_assignment : bool = False

def __post_init__(self):
for attr in ('quantize_inputs_num_bits',
'quantize_rematerialize_num_bits',
'quantize_scatter_num_bits'):
nbits = self.__getattribute__(attr)
if nbits not in _ALLOWED_BITWIDTHS:
raise ValueError(f'{attr} must be one of ' +
f'{_ALLOWED_BITWIDTHS}; got {nbits}')

if nbits != -1:
turbo.assert_turbo_is_available()

if self.__getattribute__('mlp_impl') == 'grouped':
grouped_gemm.assert_grouped_gemm_is_available()

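The net effect of this hunk is easier to see without the deleted lines interleaved. Reconstructed from the diff above (not a verbatim excerpt), the remaining validation is just the grouped-GEMM guard:

```python
# Post-commit __post_init__, reconstructed from the hunk above: the bit-width
# checks and the turbo availability assert are gone.
def __post_init__(self):
    if self.__getattribute__('mlp_impl') == 'grouped':
        grouped_gemm.assert_grouped_gemm_is_available()
```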
6 changes: 2 additions & 4 deletions megablocks/layers/dmoe.py
@@ -158,8 +158,7 @@ def sparse_forward_once(self, x, expert_weights, top_experts):
expert_weights,
bins,
padded_bins,
self.top_k,
self.args.quantize_scatter_num_bits)
self.top_k)
return x, tokens_per_expert

# For use in the base-class parallel_forward_once.
@@ -260,8 +259,7 @@ def grouped_permute_and_compute(
bin_ids,
expert_weights,
bins,
top_k,
self.args.quantize_scatter_num_bits)
top_k)

def forward_once(self, x, expert_weights, top_experts):
if self.args.mlp_impl == 'sparse':
50 changes: 4 additions & 46 deletions megablocks/layers/dmoe_test.py
@@ -3,7 +3,6 @@

from absl.testing import parameterized
from megablocks import grouped_gemm_util as gg
from megablocks import turbo_util as turbo
from megablocks.layers.arguments import Arguments
from megablocks.layers import dmoe
from megablocks.layers import moe
@@ -17,8 +16,6 @@ def test_modules(
moe_num_experts=1,
moe_capacity_factor=1,
moe_top_k=1,
num_input_bits=-1,
num_remat_bits=-1,
mlp_impl='sparse'):
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
args = Arguments(
@@ -29,8 +26,6 @@ def test_modules(
moe_top_k=moe_top_k,
init_method=init_method,
memory_optimized_mlp=True,
quantize_inputs_num_bits=num_input_bits,
quantize_rematerialize_num_bits=num_remat_bits,
mlp_type='mlp',
mlp_impl=mlp_impl,
fp16=False,
Expand All @@ -57,7 +52,7 @@ def test_modules(
return args, mlp, moe_mlp, dmoe_mlp

# min size: (1, 2, 128, 2, 1)
_FORWARD_TESTS_NO_QUANTIZE = (
_FORWARD_TESTS_DEFAULT = (
(16, 1024, 512, 1, 1),
(16, 1024, 512, 2, 1),
(16, 1024, 512, 4, 1),
@@ -74,38 +69,10 @@ def test_modules(
)

_FORWARD_TESTS_GROUPED_MLP = tuple([
p + (-1, -1, 'grouped') for p in _FORWARD_TESTS_NO_QUANTIZE
p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT
]) if gg.grouped_gemm_is_available() else ()

# quantization tests; assorted small sizes, systematic bitwidths
_FORWARD_TESTS_QUANTIZE_HIDDEN = (
(1, 2, 128, 2, 2, -1, -1),
(1, 8, 128, 2, 2, -1, 4),
(2, 8, 128, 2, 1, -1, 8),
) if turbo.turbo_is_available() else ()

_FORWARD_TESTS_QUANTIZE_INPUT = (
(1, 2, 128, 2, 1, 4, -1),
(2, 8, 128, 4, 1, 8, -1),
) if turbo.turbo_is_available() else ()

_FORWARD_TESTS_QUANTIZE_BOTH = (
(2, 2, 128, 2, 2, 4, 4),
(1, 8, 128, 4, 2, 4, 8),
(1, 2, 128, 4, 2, 8, 4),
(2, 2, 128, 4, 2, 8, 8),
) if turbo.turbo_is_available() else ()

_FORWARD_TESTS = (_FORWARD_TESTS_NO_QUANTIZE +
_FORWARD_TESTS_QUANTIZE_HIDDEN +
_FORWARD_TESTS_QUANTIZE_INPUT +
_FORWARD_TESTS_QUANTIZE_BOTH +
_FORWARD_TESTS_GROUPED_MLP)

_FORWARD_TESTS_WITH_HIDDEN_QUANTIZE = (
_FORWARD_TESTS_NO_QUANTIZE +
_FORWARD_TESTS_QUANTIZE_HIDDEN +
_FORWARD_TESTS_GROUPED_MLP)
_FORWARD_TESTS = (_FORWARD_TESTS_DEFAULT + _FORWARD_TESTS_GROUPED_MLP)


_DENSE_TESTS = (
@@ -122,7 +89,6 @@ def tearDown():

@parameterized.parameters(*_FORWARD_TESTS)
def testdMoE_Forward(self, bs, sl, hs, num_experts, top_k,
num_input_bits=-1, num_remat_bits=-1,
mlp_impl='sparse'):
x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()

@@ -131,8 +97,6 @@ def testdMoE_Forward(self, bs, sl, hs, num_experts, top_k,
ffn_hidden_size=hs * 2,
moe_num_experts=num_experts,
moe_top_k=top_k,
num_input_bits=num_input_bits,
num_remat_bits=num_remat_bits,
mlp_impl=mlp_impl)

out, _ = layer(x)
@@ -141,7 +105,6 @@ def testdMoE_Forward(self, bs, sl, hs, num_experts, top_k,
@parameterized.parameters(*_FORWARD_TESTS)
def testdMoE_ForwardBackward(
self, bs, sl, hs, num_experts, top_k,
num_input_bits=-1, num_remat_bits=-1,
mlp_impl='sparse'):
x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
x.requires_grad_(True)
@@ -151,8 +114,6 @@ def testdMoE_ForwardBackward(
ffn_hidden_size=hs * 2,
moe_num_experts=num_experts,
moe_top_k=top_k,
num_input_bits=num_input_bits,
num_remat_bits=num_remat_bits,
mlp_impl=mlp_impl)

out, _ = layer(x)
Expand All @@ -178,12 +139,9 @@ def testdMoE_ForwardVersusBaseline(self, bs, sl, hs):
self.assertSequenceEqual(expected_out.shape, x.shape)
self.assertTrue(testing.allclose(out, expected_out))

# we don't run the input quantization cases just to avoid redundancy,
# since input quantization doesn't affect any of these asserts
@parameterized.parameters(*_FORWARD_TESTS_WITH_HIDDEN_QUANTIZE)
@parameterized.parameters(*_FORWARD_TESTS)
def testdMoE_ForwardVersusMoE(
self, bs, sl, hs, num_experts, top_k,
num_input_bits=-1, num_remat_bits=-1,
mlp_impl='sparse'):
torch.manual_seed(42)

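Because the two bit-width slots are gone from the parameter tuples, the grouped variants now append only `mlp_impl`. The snippet below is an illustration of that positional mapping, not repository code:

```python
# Base tuples are (bs, sl, hs, num_experts, top_k); grouped variants append
# only the mlp_impl string, which lands in the test's final keyword slot.
_FORWARD_TESTS_DEFAULT = ((16, 1024, 512, 1, 1),)
_FORWARD_TESTS_GROUPED_MLP = tuple(
    p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT
)
# ((16, 1024, 512, 1, 1, 'grouped'),) is unpacked by
# testdMoE_Forward(self, bs, sl, hs, num_experts, top_k, mlp_impl='sparse')
```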
87 changes: 13 additions & 74 deletions megablocks/layers/glu.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights, resolve_dtensor
from megablocks.layers import mpu
from megablocks.layers.arguments import Arguments, DEFAULT_ACTIVATION_FN
from megablocks import turbo_util as turbo
from megablocks import grouped_gemm_util as gg
import stk
import torch
@@ -36,7 +35,7 @@ def forward(self, x, topo):

w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)

# Compute the GLU.
x1 = stk.ops.sdd(x, w1.t(), topo)
x2 = stk.ops.sdd(x, v1.t(), topo)
@@ -51,7 +50,7 @@ class MemoryOptimizedGroupedGLU(torch.autograd.Function):

@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, x, w1, v1, w2, batch_sizes, num_input_bits, num_remat_bits, activation_fn):
def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
# x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
if (not x.is_contiguous() or not w1.is_contiguous() or
not v1.is_contiguous() or not w2.is_contiguous()):
@@ -61,27 +60,8 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, num_input_bits, num_remat_bits, act
sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

# Save input tensor, quantizing if needed
input_save_args = (x,)
if num_input_bits != -1:
x_q, x_scales = turbo.quantize_signed(x, num_bits=num_input_bits)
input_save_args = (x_q, x_scales)

# GeLU.
if num_remat_bits == -1:
activation_fn_out = activation_fn(sdd_out) * v1_out
input_save_args += (sdd_out, v1_out,)
else:
if activation_fn is not DEFAULT_ACTIVATION_FN:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
# Fused GELU into sdd_out buffer while quantizing input
hidden_q_sdd, hidden_scales_sdd, _ = turbo.quantize_signed(
sdd_out, num_bits=num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD, x_forward=sdd_out)
activation_fn_out = sdd_out * v1_out
hidden_q_v1, hidden_scales_v1 = turbo.quantize_signed(
v1_out, num_bits=num_remat_bits)
input_save_args += (hidden_q_sdd, hidden_scales_sdd, hidden_q_v1, hidden_scales_v1)
activation_fn_out = activation_fn(sdd_out) * v1_out

# Layer 1: x @ w2.
dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
@@ -90,13 +70,11 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, num_input_bits, num_remat_bits, act
# gradient computation. We'll re-compute the activation_fn forward
# pass in the backward pass to avoid materializing another
# intermediate.
ctx.num_input_bits = num_input_bits
ctx.num_remat_bits = num_remat_bits
ctx.x_shape = x.shape
ctx.sdd_out_shape = sdd_out.shape
ctx.dtype = x.dtype
ctx.activation_fn = activation_fn
ctx.save_for_backward(w1, v1, w2, batch_sizes, *input_save_args)
ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
return dsd_out

@staticmethod
@@ -107,44 +85,21 @@ def backward(ctx, ddsd_out):
not ctx.needs_input_grad[2]):
raise ValueError("Expected all MLP inputs to need grad.")

# Unpack saved tensors; ugly because quantizing changes tensor count
# Unpack saved tensors
dtype = ctx.dtype
saved_tensors = ctx.saved_tensors
w1, v1, w2 = saved_tensors[:3]
batch_sizes = saved_tensors[3]

# Either 1 or 2 tensors for MLP input after the always-present tensors
if ctx.num_input_bits == -1:
x = saved_tensors[4]
else:
x_q, x_scales = saved_tensors[4:6]

# Either 1 or 4 tensors at the end for saved GELU input / sdd output
if ctx.num_remat_bits == -1:
sdd_out, v1_out = saved_tensors[-2:]
else:
hidden_q_sdd, hidden_scales_sdd, hidden_q_v1, hidden_scales_v1 = saved_tensors[-4:]
x = saved_tensors[4]
sdd_out, v1_out = saved_tensors[5:7]

# Rematerialize activation_fn output.
activation_fn = ctx.activation_fn
activation_grad_fn = None
if ctx.num_remat_bits == -1:
with torch.set_grad_enabled(True):
with torch.set_grad_enabled(True):
sdd_out.requires_grad = True
v1_out.requires_grad = True
activation_fn_out = activation_fn(sdd_out) * v1_out
activation_grad_fn = activation_fn_out.backward
else:
if activation_fn is not DEFAULT_ACTIVATION_FN:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({ctx.num_remat_bits=}).')
sdd_out = turbo.dequantize_signed(
hidden_q_sdd, hidden_scales_sdd, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD,
out_shape=ctx.sdd_out_shape, out_dtype=dtype)
v1_out = turbo.dequantize_signed(
hidden_q_v1, hidden_scales_v1, num_bits=ctx.num_remat_bits,
out_shape=ctx.sdd_out_shape, out_dtype=dtype)
activation_fn_out = sdd_out * v1_out

# Compute dw2 with recomputed activation_fn output.
dw2 = gg.backend.gmm(
@@ -160,24 +115,10 @@ def backward(ctx, ddsd_out):
# Compute dsdd_out.
#
# NOTE: This reuses the dactivation_fn_out allocation.
if ctx.num_remat_bits == -1:
assert activation_grad_fn is not None
activation_grad_fn(dactivation_fn_out)
dsdd_out = sdd_out.grad
dv1_out = v1_out.grad
else:
# confusingly, x_out is interpreted as the gradient to overwrite
# in-place when the elemwise op is a backwards op
dsdd_out = turbo.dequantize_signed(
hidden_q_sdd, hidden_scales_sdd, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=(dactivation_fn_out * v1_out).data)
dv1_out = (dactivation_fn_out * sdd_out).data

# rematerialize MLP input now that we need it
if ctx.num_input_bits != -1:
x = turbo.dequantize_signed(
x_q, x_scales, num_bits=ctx.num_input_bits,
out_dtype=dtype, out_shape=ctx.x_shape)
assert activation_grad_fn is not None
activation_grad_fn(dactivation_fn_out)
dsdd_out = sdd_out.grad
dv1_out = v1_out.grad

# Compute dw1.
dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
@@ -191,7 +132,7 @@ def backward(ctx, ddsd_out):
dx = ddsd_out
gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
return dx, dw1, dv1, dw2, None, None, None, None
return dx, dw1, dv1, dw2, None, None

memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply

@@ -211,8 +152,6 @@ def forward(self, x, tokens_per_expert):
if self.args.memory_optimized_mlp:
return memory_optimized_grouped_glu(
x, w1, v1, w2, batch_sizes,
self.args.quantize_inputs_num_bits,
self.args.quantize_rematerialize_num_bits,
self.args.activation_fn)

# Compute the MLP.
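With turbo removed, the memory-optimized GLU saves only `x`, `sdd_out`, and `v1_out` and recomputes the activation product during backward. The following is a compact, self-contained sketch of that rematerialization pattern, with plain dense matmuls standing in for the grouped GEMMs and names chosen for illustration; it is not the repository implementation.

```python
import torch


class RematGLU(torch.autograd.Function):
    """Toy dense GLU, y = (act(x @ w1.T) * (x @ v1.T)) @ w2, that stores only
    the pre-activation tensors and replays the activation in backward."""

    @staticmethod
    def forward(ctx, x, w1, v1, w2, activation_fn):
        sdd_out = x @ w1.t()   # [m, n]
        v1_out = x @ v1.t()    # [m, n]
        out = (activation_fn(sdd_out) * v1_out) @ w2  # [m, k]
        ctx.activation_fn = activation_fn
        # The activation output itself is deliberately not saved.
        ctx.save_for_backward(x, w1, v1, w2, sdd_out, v1_out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, v1, w2, sdd_out, v1_out = ctx.saved_tensors
        # Rematerialize the activation product so autograd can produce the
        # gradients w.r.t. sdd_out and v1_out, mirroring the diff above.
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            hidden = ctx.activation_fn(sdd_out) * v1_out
        dw2 = hidden.t() @ grad_out              # [n, k]
        hidden.backward(grad_out @ w2.t())       # fills sdd_out.grad, v1_out.grad
        dsdd_out, dv1_out = sdd_out.grad, v1_out.grad
        dw1 = dsdd_out.t() @ x                   # [n, k]
        dv1 = dv1_out.t() @ x                    # [n, k]
        dx = dsdd_out @ w1 + dv1_out @ v1        # [m, k]
        return dx, dw1, dv1, dw2, None


# Usage sketch; shapes follow the comment in glu.py (x: [m, k], weights: [n, k]).
m, k, n = 8, 16, 32
x = torch.randn(m, k, requires_grad=True)
w1, v1, w2 = (torch.randn(n, k, requires_grad=True) for _ in range(3))
y = RematGLU.apply(x, w1, v1, w2, torch.nn.functional.gelu)
y.sum().backward()
```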
