NLP and Transformers
CSE 891: Deep Learning
Vishnu Boddeti
Wednesday, October 20, 2021
| Layer Type | Complexity per Layer | Sequential Ops | Max Path Length |
|---|---|---|---|
| Self-Attention | $\mathcal{O}(n^2 \cdot d)$ | $\mathcal{O}(1)$ | $\mathcal{O}(1)$ |
| Recurrent | $\mathcal{O}(n \cdot d^2)$ | $\mathcal{O}(n)$ | $\mathcal{O}(n)$ |
| Convolutional | $\mathcal{O}(k \cdot n \cdot d^2)$ | $\mathcal{O}(1)$ | $\mathcal{O}(\log_k(n))$ |
| Self-Attention (restricted) | $\mathcal{O}(r \cdot n \cdot d)$ | $\mathcal{O}(1)$ | $\mathcal{O}(n/r)$ |
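In the table, $n$ is the sequence length, $d$ the representation dimension, $k$ the convolution kernel size, and $r$ the size of the restricted attention neighborhood. The $\mathcal{O}(n^2 \cdot d)$ per-layer cost of self-attention comes from the $n \times n$ score matrix, while the $\mathcal{O}(1)$ sequential-ops and path-length entries reflect that every pair of positions interacts in a single matrix product. A minimal single-head sketch in PyTorch (the function name and shapes below are illustrative, not from the lecture):

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, n, d). The score matrix is (batch, n, n), so time and
    # memory scale as O(n^2 * d) per layer, but all positions are processed
    # in parallel (O(1) sequential operations, O(1) max path length).
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # (batch, n, n)
    weights = scores.softmax(dim=-1)
    return weights @ v                               # (batch, n, d)

x = torch.randn(2, 128, 64)  # batch of 2, n = 128 tokens, d = 64
print(scaled_dot_product_attention(x, x, x).shape)  # torch.Size([2, 128, 64])
```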
| Model | Layers | Width | Heads | Params | Data | Training |
|---|---|---|---|---|---|---|
| Transformer-Base | 12 | 512 | 8 | 65M | | 8x P100 (12 hrs) |
| Transformer-Large | 12 | 1024 | 16 | 213M | | 8x P100 (3.5 days) |
| BERT-Base | 12 | 768 | 12 | 110M | 13GB | |
| BERT-Large | 24 | 1024 | 16 | 340M | 13GB | |
| XLNet-Large | 24 | 1024 | 16 | 340M | 126GB | 512x TPU-v3 (2.5 days) |
| RoBERTa | 24 | 1024 | 16 | 355M | 160GB | 1024x V100 (1 day) |
| GPT-2 | 48 | 1600 | ? | 1.5B | 40GB | |
| Megatron-LM | 72 | 3072 | 32 | 8.3B | 174GB | 512x V100 (9 days) |
| Turing-NLG | 78 | 4256 | 28 | 17B | ? | 256x V100 |
| GPT-3 | 96 | 12288 | 96 | 175B | 694GB | ? |
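The parameter counts above can be sanity-checked with a rough estimate of about $12 L d^2$ for an $L$-layer stack of width $d$: roughly $4d^2$ for the attention projections (query, key, value, output) and $8d^2$ for the feed-forward block with a 4x expansion, ignoring embeddings, biases, and layer norms. The helper below is a hypothetical sketch under those assumptions, not code from the lecture:

```python
def approx_transformer_params(num_layers, width, vocab_size=0):
    """Rough Transformer parameter count: 4*d^2 for the attention projections
    plus 8*d^2 for the 4x-expansion feed-forward block, per layer, plus an
    optional token-embedding matrix. Biases and layer norms are ignored."""
    per_layer = 4 * width ** 2 + 8 * width ** 2
    return num_layers * per_layer + vocab_size * width

# Compared with the table above:
print(approx_transformer_params(24, 1024) / 1e6)   # ~302M  (BERT-Large: 340M)
print(approx_transformer_params(96, 12288) / 1e9)  # ~174B  (GPT-3: 175B)
```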
from functools import partial

from torch import nn
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    """Applies LayerNorm before `fn` and adds a residual connection."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    # Two-layer MLP with GELU; `dense` is nn.Linear for channel mixing
    # or a 1x1 nn.Conv1d for token mixing.
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes,
             expansion_factor=4, dropout=0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
    return nn.Sequential(
        # Split the image into non-overlapping patches and flatten each patch.
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
        # Per-patch linear embedding.
        nn.Linear((patch_size ** 2) * channels, dim),
        # Mixer blocks: token mixing (across patches), then channel mixing.
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        # Global average pooling over patches, then the classifier head.
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
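A quick smoke test of the PyTorch model above; the hyperparameters are illustrative (roughly in the range of the small Mixer configurations) rather than values taken from the lecture:

```python
import torch

model = MLPMixer(image_size=224, channels=3, patch_size=16,
                 dim=512, depth=8, num_classes=1000)
img = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
print(model(img).shape)            # torch.Size([1, 1000])
```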
from typing import Any

import einops
import flax.linen as nn
import jax.numpy as jnp

class MlpBlock(nn.Module):
  """Two-layer MLP with a GELU nonlinearity."""
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)

class MixerBlock(nn.Module):
  """Mixer block layer: token mixing followed by channel mixing."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    # Token mixing: MLP applied across the patch (token) dimension.
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    # Channel mixing: MLP applied across the channel dimension.
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)

class MlpMixer(nn.Module):
  """Mixer architecture."""
  patches: Any
  num_classes: int
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, inputs, *, train):
    del train
    # Patch embedding: a strided convolution over non-overlapping patches.
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    # Global average pooling over patches, then the classifier head.
    x = jnp.mean(x, axis=1)
    if self.num_classes:
      x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
    return x
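A minimal sketch of initializing and calling the Flax model above. The code expects `patches` to provide a `.size` field, so a `SimpleNamespace` stands in for the config object here, and the hyperparameters are illustrative (roughly a small Mixer configuration):

```python
import jax
from types import SimpleNamespace

model = MlpMixer(patches=SimpleNamespace(size=(16, 16)), num_classes=1000,
                 num_blocks=8, hidden_dim=512,
                 tokens_mlp_dim=256, channels_mlp_dim=2048)
dummy = jnp.ones((1, 224, 224, 3))  # NHWC input
variables = model.init(jax.random.PRNGKey(0), dummy, train=False)
logits = model.apply(variables, dummy, train=False)
print(logits.shape)                 # (1, 1000)
```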