Skip to content

Commit

Permalink
adapted model files for PR#2 'adjust qllama call to output hidden states'
Browse files Browse the repository at this point in the history
  • Loading branch information
Haojin Yang committed Sep 11, 2024
1 parent ea02447 commit 3a46cf8
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# GBA Model Toolkit for M LX
# GBA Model Toolkit for MLX

## Introduction
Welcome to the GreenBitAI (GBA) Model Toolkit for [MLX](https://github.com/ml-explore/mlx)! This comprehensive Python package not only facilitates the conversion of [GreenBitAI's Low-bit Language Models (LLMs)](https://huggingface.co/collections/GreenBitAI/greenbitai-mlx-llm-6614eb6ceb8da657c2b4ed58) to MLX framework compatible format but also supports generation, model loading, and other essential scripts tailored for GBA quantized models. Designed to enhance the integration and deployment of GBA models within the MLX ecosystem, this toolkit enables the efficient execution of GBA models on a variety of platforms, with special optimizations for Apple devices to enable local inference and natural language content generation.
Expand Down Expand Up @@ -54,7 +54,7 @@ A high-performance HTTP API for text generation with GreenBitAI's mlx models. Im
#### Quick Start
1. Run:
```shell
python -m gbx_lm.fastapi_server.py --model GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0-mlx
python -m gbx_lm.fastapi_server --model GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0-mlx
```
2. Use:
```shell
Expand Down
3 changes: 2 additions & 1 deletion gbx_lm/models/qgemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = (self.model.embed_tokens.as_linear(out), out) if hidden_states else self.model.embed_tokens.as_linear(out)
return out

@property
Expand Down
6 changes: 4 additions & 2 deletions gbx_lm/models/qllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,14 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
out = (self.model.embed_tokens.as_linear(out), out) if hidden_states else self.model.embed_tokens.as_linear(
out)
else:
out = self.lm_head(out)
out = (self.lm_head(out), out) if hidden_states else self.lm_head(out)
return out

def sanitize(self, weights):
Expand Down
4 changes: 3 additions & 1 deletion gbx_lm/models/qmixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
return self.lm_head(out)
out = (self.lm_head(out), out) if hidden_states else self.lm_head(out)
return out

def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
Expand Down
4 changes: 3 additions & 1 deletion gbx_lm/models/qphi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,11 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
return self.lm_head(out)
out = (self.lm_head(out), out) if hidden_states else self.lm_head(out)
return out

@property
def layers(self):
Expand Down
6 changes: 4 additions & 2 deletions gbx_lm/models/qqwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
out = (self.model.embed_tokens.as_linear(out), out) if hidden_states else self.model.embed_tokens.as_linear(
out)
else:
out = self.lm_head(out)
out = (self.lm_head(out), out) if hidden_states else self.lm_head(out)
return out

def sanitize(self, weights):
Expand Down
6 changes: 4 additions & 2 deletions gbx_lm/models/qstarcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,14 @@ def __call__(
self,
inputs: mx.array,
cache=None,
hidden_states=False
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
out = (self.model.embed_tokens.as_linear(out), out) if hidden_states else self.model.embed_tokens.as_linear(
out)
else:
out = self.lm_head(out)
out = (self.lm_head(out), out) if hidden_states else self.lm_head(out)
return out

@property
Expand Down

0 comments on commit 3a46cf8

Please sign in to comment.