adding-support-for-mamba2 #1009

Open: wants to merge 49 commits into main

Changes from 1 commit

Commits (49)
49b9fc1
Create mamba2.py
Goekdeniz-Guelmez Oct 2, 2024
409ddc4
updating ACKNOWLEDGMENTS.md file
Goekdeniz-Guelmez Oct 2, 2024
264ba43
update trainer/lora.py and add DepthWiseConv1d because mlx 0.18.0 …
Goekdeniz-Guelmez Oct 2, 2024
52d6ca0
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 4, 2024
4e1236c
fixing loading the model
Goekdeniz-Guelmez Oct 11, 2024
9c075a7
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 11, 2024
6f88dd5
quick clean up and fix
Goekdeniz-Guelmez Oct 11, 2024
00ba27f
adding debug statements
Goekdeniz-Guelmez Oct 11, 2024
3f1c1dd
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 14, 2024
855fcc4
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 16, 2024
8073cb4
adding debug statements (somehow generating only goes through the fis…
Goekdeniz-Guelmez Oct 16, 2024
181d6ab
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 16, 2024
cd036cc
fix generation works too (almost)
Goekdeniz-Guelmez Oct 16, 2024
4ab5139
quick save
Goekdeniz-Guelmez Oct 20, 2024
ab4cf1d
generation works but outputs gibberish
Goekdeniz-Guelmez Oct 20, 2024
c1634ce
still generating gibberish
Goekdeniz-Guelmez Oct 20, 2024
0ef73f3
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 21, 2024
b9c57cd
generation works! trying training now
Goekdeniz-Guelmez Oct 22, 2024
5326d93
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 22, 2024
758597e
adding multi token input and correct cache handling in ssm step
Goekdeniz-Guelmez Oct 22, 2024
55485b9
update
Goekdeniz-Guelmez Oct 22, 2024
e43a2ab
not working, incorrect handling with cache probably
Goekdeniz-Guelmez Oct 22, 2024
9ab581d
notes
Goekdeniz-Guelmez Oct 22, 2024
a677638
inference works but is hella slow
Goekdeniz-Guelmez Oct 22, 2024
7c8849e
update
Goekdeniz-Guelmez Oct 24, 2024
3b70708
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 25, 2024
ffc7ab0
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 30, 2024
58b448d
updates
Goekdeniz-Guelmez Oct 30, 2024
906f972
save push
Goekdeniz-Guelmez Nov 6, 2024
800b602
save checkpoint
Goekdeniz-Guelmez Nov 10, 2024
3a499f9
fixed inference slowness but it can't handle multiple token inputs and…
Goekdeniz-Guelmez Nov 10, 2024
49d3f18
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Nov 10, 2024
2f95b36
removed the custom Mamba2Cache and updated the existing MambaCache bu…
Goekdeniz-Guelmez Nov 10, 2024
1a66883
implemented multi-token inputs, but still generating gibberish
Goekdeniz-Guelmez Nov 10, 2024
1d85106
nits
Goekdeniz-Guelmez Nov 10, 2024
e4eae97
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Nov 21, 2024
e22b2db
Fixed streaming generation and got rid of generating gibberish, but i…
Goekdeniz-Guelmez Nov 21, 2024
117ffd3
removing some files
Goekdeniz-Guelmez Nov 21, 2024
57b1717
inference fixed
Goekdeniz-Guelmez Nov 21, 2024
a6ddc27
removing last checkpoint file
Goekdeniz-Guelmez Nov 21, 2024
38e5801
loading codestral works but not inference
Goekdeniz-Guelmez Nov 24, 2024
ddad210
Merge branch 'main' into adding-support-for-mamba2
Goekdeniz-Guelmez Dec 10, 2024
9f8a6a3
inference on codestral works but is gibberish
Goekdeniz-Guelmez Dec 10, 2024
b10afe3
nits
Goekdeniz-Guelmez Dec 10, 2024
80e88b4
nits
Goekdeniz-Guelmez Dec 10, 2024
184d3d3
clean up
Goekdeniz-Guelmez Dec 10, 2024
c1d9ec3
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Dec 10, 2024
a883e39
optimizing the code for faster inference but still generates gibberish
Goekdeniz-Guelmez Dec 12, 2024
dff4e52
adding the model names in the LORA.md file and removing unused functio…
Goekdeniz-Guelmez Dec 12, 2024
inference works but is hella slow
Goekdeniz-Guelmez committed Oct 22, 2024
commit a677638c4bf2395784c083b2546ed779c4bfa5f4
llms/mlx_lm/models/cache.py: 2 changes (1 addition, 1 deletion)

@@ -342,7 +342,7 @@ def state(self, v):
 
 class Mamba2Cache(_BaseCache):
     conv_states: Optional[mx.array] = None
-    ssm_states: Optional[mx.array] = None
+    ssm_state: Optional[mx.array] = None
 
     def __getitem__(self, idx: int) -> Optional[mx.array]:
         if idx == 0:
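For orientation, a minimal sketch of the cache shape this commit settles on: one per-layer buffer for the causal-conv window and one for the SSM recurrence. Only the `ssm_states` to `ssm_state` rename is in this diff; the `__setitem__` below is an assumption that mirrors the `__getitem__` context shown above.

```python
import mlx.core as mx
from typing import Optional

class Mamba2CacheSketch:
    """Illustrative stand-in for Mamba2Cache; not the PR's exact code."""
    conv_states: Optional[mx.array] = None  # (B, conv_kernel - 1, conv_dim)
    ssm_state: Optional[mx.array] = None    # (B, num_heads, head_dim, state_size)

    def __getitem__(self, idx: int) -> Optional[mx.array]:
        # Convention from the diff: index 0 is the conv window, index 1 the SSM state.
        return self.conv_states if idx == 0 else self.ssm_state

    def __setitem__(self, idx: int, value: Optional[mx.array]):  # assumed, not shown in the diff
        if idx == 0:
            self.conv_states = value
        else:
            self.ssm_state = value
```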
llms/mlx_lm/models/mamba2.py: 192 changes (59 additions, 133 deletions)

@@ -103,89 +103,55 @@ def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=Non
         assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
         assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
 
-        # Weight shape: (channels, 1, kernel_size) to match pretrained weights
         self.weight = mx.random.normal((in_channels, 1, kernel_size))
         self.bias = mx.zeros((out_channels,)) if bias else None
 
-    def __call__(self, x: mx.array, cache=None, cache_idx: int = 0) -> mx.array:
+    def __call__(self, x: mx.array, cache=None) -> mx.array:
         B, L, C = x.shape
         K = self.kernel_size
 
-        # Validate input dimensions
-        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
-
-        # Handle padding and caching
-        if cache is not None:
-            conv_states = cache[cache_idx]
+        if cache is not None and 'conv_states' in cache:
+            conv_states = cache['conv_states']
             if conv_states is not None:
-                # Validate cache shape
-                assert conv_states.shape[0] == B, "Cache batch size mismatch"
-                assert conv_states.shape[2] == C, "Cache channel count mismatch"
                 x = mx.concatenate([conv_states, x], axis=1)
                 L = x.shape[1]
         else:
-            # Add left padding of size (kernel_size - 1)
             pad_left = K - 1
            x = mx.pad(x, [(0, 0), (pad_left, 0), (0, 0)])
             L = x.shape[1]
 
-        # Pre-allocate output array if possible
-        outputs = []
-
-        # Process each channel independently
+        outputs = []
         for c in range(C):
-            # Extract and prepare channel data
-            x_c = x[:, :, c]  # Shape: [B, L]
-            x_c = mx.expand_dims(x_c, axis=1)  # Shape: [B, 1, L]
+            x_c = x[:, :, c]
+            x_c = mx.expand_dims(x_c, axis=1)
 
-            # Prepare filter weights
-            w_c = self.weight[c]  # Get channel weights
-            # Ensure filter is 3D: [depth(1), in_channels(1), kernel_size]
+            w_c = self.weight[c]
             if w_c.ndim == 2:
                 w_c = mx.expand_dims(w_c, axis=0)
             elif w_c.ndim == 1:
                 w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
 
-            # Handle inference mode (single token)
             if L < K:
                 pad_size = K - L
                 x_c = mx.pad(x_c, [(0, 0), (0, 0), (pad_size, 0)])
 
-            # Apply 1D convolution
-            try:
-                y_c = mx.conv_general(
-                    x_c,
-                    w_c,
-                    stride=1,
-                    padding=0  # Padding already handled
-                )
-
-                if self.bias is not None:
-                    y_c = y_c + self.bias[c]
-
-                # Remove singleton dimension and add to outputs
-                outputs.append(mx.squeeze(y_c, axis=1))
-
-            except Exception as e:
-                raise RuntimeError(f"Convolution failed for channel {c}. Shapes: input={x_c.shape}, weight={w_c.shape}") from e
-
-        # Stack channel outputs along last dimension
-        y = mx.stack(outputs, axis=-1)  # Shape: [B, L', C]
+            # Apply convolution
+            y_c = mx.conv_general(
+                x_c,
+                w_c,
+                stride=1,
+                padding=0
+            )
+
+            if self.bias is not None:
+                y_c = y_c + self.bias[c]
+
+            outputs.append(mx.squeeze(y_c, axis=1))
+
+        y = mx.stack(outputs, axis=-1)
 
-        # Update cache if needed
+        # Update cache
         if cache is not None:
-            # Store last (kernel_size - 1) tokens or entire input if shorter
-            new_cache = x[:, -(K-1):, :] if L >= K else x
-            cache[cache_idx] = new_cache
-
-            if new_cache.shape != cache[cache_idx].shape:
-                cache[cache_idx] = new_cache
-                print(f"Cache updated at index {cache_idx}")
-            else:
-                print(f"Skipping cache update at index {cache_idx}, shapes are identical.")
+            cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x
 
         return y
 
 
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
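The per-channel Python loop above is the likely source of the slowness this commit's message admits to: C separate `mx.conv_general` calls per forward pass. As a hedged sketch, assuming an MLX version whose `mx.conv1d` accepts a `groups` argument (current releases do; the hand-rolled class suggests the version this PR targeted did not), the whole layer reduces to one grouped convolution. The function name and the `state` convention are illustrative, not this PR's API:

```python
import mlx.core as mx

def depthwise_conv1d(x: mx.array, weight: mx.array, bias=None, state=None):
    """x: (B, L, C); weight: (C, 1, K), the layout the PR stores; state: (B, K - 1, C)."""
    B, L, C = x.shape
    K = weight.shape[-1]
    if state is None:
        # Prefill: zeros act as the left padding of size K - 1.
        state = mx.zeros((B, K - 1, C))
    x = mx.concatenate([state, x], axis=1)  # prepend the cached window
    w = mx.swapaxes(weight, 1, 2)           # (C, K, 1), i.e. (C_out, K, C_in / groups)
    y = mx.conv1d(x, w, stride=1, padding=0, groups=C)
    if bias is not None:
        y = y + bias                        # (C,) broadcasts over the channel axis
    new_state = x[:, -(K - 1):, :]          # keep the last K - 1 tokens for the next call
    return y, new_state
```

Output length equals input length (L + K - 1 positions, kernel K, no padding), and C kernel launches collapse into one.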
@@ -217,81 +183,49 @@ def __init__(self, args: ModelArgs):
         self.out_proj.weight = self.out_proj.weight * layer_scale
 
     def __call__(self, x: mx.array, cache=None):
-        # if cache is not None and self.args.use_cache:
         if cache is not None:
             return self.step(x, cache)
 
-        # Calculate sizes
+        # Regular forward pass code remains the same...
         d_model = self.args.intermediate_size
         d_state = self.args.state_size
         n_heads = self.args.num_heads
 
-        # Compute A
         A = -mx.exp(self.A_log)
 
-        # Project input
         zxbcdt = self.in_proj(x)
 
-        # Correct splits for z, xBC, dt
-        splits = [
-            d_model,                # z
-            d_model + 2 * d_state,  # xBC (delta, B, C concatenated)
-            n_heads                 # dt
-        ]
-
-        # Split using cumulative indices
+        splits = [d_model, d_model + 2 * d_state, n_heads]
         z = zxbcdt[:, :, :splits[0]]
         xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
         dt = zxbcdt[:, :, -splits[2]:]
 
-        # Process dt
         dt = mx.clip(
             nn.softplus(dt + self.dt_bias),
             self.args.time_step_min,
             self.args.time_step_max
         )
         dt = mx.maximum(dt, self.args.time_step_floor)
 
-        # Process convolution
         xBC = silu(self.conv1d(xBC))
 
-        # Split convolved xBC into x, B, C
         x = xBC[:, :, :d_model]
         B = xBC[:, :, d_model:d_model + d_state]
         C = xBC[:, :, -d_state:]
 
-        # Reshape for SSM computation
         b, l, hp = x.shape
         h = self.args.num_heads
         p = hp // h
         x = mx.reshape(x, (b, l, h, p))
 
-        # Compute SSM
-        y, ssm_state = ssd(
-            x * mx.expand_dims(dt, -1),
-            A * dt,
-            B,
-            C,
-            self.args.chunk_size
-        )
-
-        # Add skip connection
+        y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
         y = y + x * mx.expand_dims(self.D, -1)
 
-        # Reshape back
         y = mx.reshape(y, (b, l, h * p))
 
-        # Apply norm and projection
         y = self.norm(y + z)
         y = self.out_proj(y)
 
-        # Update cache if needed
         if cache is not None and self.args.use_cache:
             cache[1] = ssm_state
 
-        # Cast if needed
         if self.args.residual_in_fp32:
-            y.astype(mx.float32)
+            y = y.astype(mx.float32)
 
         return y
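`ssd` itself is not part of this diff, so as a sanity check, here is a naive sequential sketch of the recurrence it is expected to compute, matching the elementwise update in `step` below: per head, the state decays by exp(dt * A), receives a rank-1 update from the dt-scaled input and B, and is read out through C. A real chunked `ssd` should be numerically equivalent, just faster:

```python
import mlx.core as mx

def ssd_reference(x, dA, B, C):
    """x: (b, l, h, p), already scaled by dt; dA: (b, l, h), holding A * dt;
    B, C: (b, l, n). Returns y: (b, l, h, p) and the final SSM state."""
    b, l, h, p = x.shape
    n = B.shape[-1]
    state = mx.zeros((b, h, p, n))
    ys = []
    for t in range(l):
        decay = mx.exp(dA[:, t])[:, :, None, None]           # (b, h, 1, 1)
        dBx = x[:, t, :, :, None] * B[:, t, None, None, :]   # outer product: (b, h, p, n)
        state = decay * state + dBx                          # linear recurrence over time
        ys.append((state * C[:, t, None, None, :]).sum(-1))  # contract n: (b, h, p)
    return mx.stack(ys, axis=1), state
```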

@@ -300,36 +234,35 @@ def step(self, u: mx.array, cache):
         seq_len = u.shape[1]
         outputs = []
 
-        # Initialize SSM state if needed
-        if cache[1] is None:
-            cache[1] = mx.zeros((
+        # Initialize cache if needed
+        if cache.conv_states is None:
+            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+            cache.conv_states = mx.zeros((
+                batch_size,
+                self.args.conv_kernel - 1,
+                conv_dim
+            ))
+
+        if cache.ssm_state is None:
+            cache.ssm_state = mx.zeros((
                 batch_size,
                 self.args.num_heads,
                 self.args.head_dim,
                 self.args.state_size
             ))
 
         for pos in range(seq_len):
-            # Getting stuck here in last position, also cache from pos 0 is the same.
-            # Get single token
             u_t = u[:, pos:pos+1, :]
 
-            # Project input
             zxbcdt = self.in_proj(u_t)
 
-            # Calculate sizes
             d_model = self.args.intermediate_size
             d_state = self.args.state_size
             n_heads = self.args.num_heads
             d_head = self.args.head_dim
 
-            # Split projected input
-            # conv_dim = d_model + 2 * d_state (this should match self.conv1d.in_channels)
             z = zxbcdt[:, :, :d_model]
-            xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]  # Include the full conv dimension
+            xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
             dt = zxbcdt[:, :, -(n_heads):]
 
-            # Process dt
             dt = mx.reshape(dt, (batch_size, n_heads))
             dt = mx.clip(
                 nn.softplus(dt + self.dt_bias),
@@ -338,49 +271,43 @@ def step(self, u: mx.array, cache):
             )
             dt = mx.maximum(dt, self.args.time_step_floor)
 
-            # Process convolution with correct dimensions
-            xBC = self.conv1d(xBC, cache=cache, cache_idx=0)
+            # Create a temporary cache dictionary for the convolution
+            conv_cache = {'conv_states': cache.conv_states}
+            xBC = self.conv1d(xBC, cache=conv_cache)
+            cache.conv_states = conv_cache['conv_states']
 
             xBC = silu(xBC)
 
-            # Split convolved xBC into x, B, C with correct dimensions
             x = xBC[:, :, :d_model]
             B = xBC[:, :, d_model:d_model + d_state]
             C = xBC[:, :, -d_state:]
 
-            # Reshape tensors for SSM computation
-            x = mx.reshape(x, (batch_size, 1, n_heads, d_head))
-            x = mx.squeeze(x, axis=1)  # (batch, heads, dim)
+            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+            x = mx.squeeze(x, axis=1)
 
             B = mx.reshape(B, (batch_size, 1, d_state))
             B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
-            B = mx.expand_dims(B, axis=2)  # (batch, heads, 1, state)
+            B = mx.expand_dims(B, axis=2)
 
             C = mx.reshape(C, (batch_size, 1, d_state))
             C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
-            C = mx.expand_dims(C, axis=3)  # (batch, heads, state, 1)
+            C = mx.expand_dims(C, axis=3)
 
-            # Compute SSM updates
             A = -mx.exp(self.A_log)
             dA = mx.exp(dt * mx.expand_dims(A, 0))
-            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)  # (batch, heads, 1, 1)
+            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
 
-            # Update state with proper shapes
-            x = mx.expand_dims(x, axis=3)  # (batch, heads, dim, 1)
-            dBx = mx.matmul(x, B)  # (batch, heads, dim, state)
+            x = mx.expand_dims(x, axis=3)
+            dBx = mx.matmul(x, B)
 
-            ssm_state = cache[1]
-            ssm_state = ssm_state * dA + dBx
-            cache[1] = ssm_state
+            cache.ssm_state = cache.ssm_state * dA + dBx
 
-            # Compute output
-            y = mx.matmul(ssm_state, C)  # (batch, heads, dim, 1)
-            y = mx.squeeze(y, axis=-1)  # (batch, heads, dim)
+            y = mx.matmul(cache.ssm_state, C)
+            y = mx.squeeze(y, axis=-1)
 
-            # Add skip connection
             y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
 
-            # Reshape and process output
-            y = mx.reshape(y, (batch_size, 1, n_heads * d_head))
+            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
             y = self.norm(y + z)
             y = self.out_proj(y)
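Stripped of the reshape bookkeeping, the per-token update in the loop above is a decay, a rank-1 state write, and a readout. A shape-level sketch (b = batch, h = heads, p = head_dim, n = state_size), equivalent to the matmul version in the loop:

```python
import mlx.core as mx

def ssm_step(ssm_state, x_t, B_t, C_t, dA):
    """ssm_state: (b, h, p, n); x_t: (b, h, p); B_t, C_t: (b, h, n); dA: (b, h), post-exp decay."""
    dBx = x_t[..., None] * B_t[:, :, None, :]           # outer product: (b, h, p, n)
    ssm_state = dA[..., None, None] * ssm_state + dBx   # decay old state, add new input
    y_t = (ssm_state * C_t[:, :, None, :]).sum(-1)      # readout through C: (b, h, p)
    return y_t, ssm_state
```

Note that the Python-level `for pos in range(seq_len)` processes the prompt one token at a time, consistent with the slowness this commit's message reports; routing multi-token inputs through the chunked `ssd` path is the usual remedy.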

@@ -440,7 +367,6 @@ def __call__(self, inputs: mx.array, cache=None):
         else:
             logits = self.lm_head(x)
 
-        print('ouput')
         return logits
 
     def make_cache(self, batch_size=1):
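`make_cache` is cut off in this view. By analogy with the Mamba-1 support already in mlx-lm, it plausibly returns one cache object per residual layer; the body below is an assumption, including the `backbone.layers` attribute path, not this PR's code:

```python
def make_cache(self, batch_size=1):
    # Hypothetical sketch: one recurrent cache per layer.
    return [Mamba2Cache() for _ in range(len(self.backbone.layers))]
```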