Skip to content

Commit

Permalink
start adding cohere rerank support for cross-encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ystoneman committed Apr 16, 2024
1 parent befef57 commit 0042052
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 17 deletions.
5 changes: 5 additions & 0 deletions bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ export function getConfig(): SystemConfig {
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
},
{
provider: "cohere",
name: "rerank-english-v3.0",
default: false,
},
],
},
};
Expand Down
4 changes: 4 additions & 0 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,10 @@ async function processCreateOptions(options: any): Promise<void> {
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.crossEncoderModels[1] = {
provider: "cohere",
name: "rerank-english-v3.0",
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
Expand Down
16 changes: 16 additions & 0 deletions docs/documentation/inference-script.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,19 @@ The API is JSON body based:
"passages": ["I love Paris", "I love London"]
}
```

## Cohere Rerank 3

To use the Cohere Rerank 3 model, get an API key from Cohere, and include the following in the JSON request body:

```json
{
"type": "cross-encoder",
"model": "rerank-english-v3.0",
"input": "What is the capital of the United States?",
"passages": [
"Carson City is the capital city of the American state of Nevada.",
"Washington, D.C. is the capital of the United States.",
...
]
}
37 changes: 21 additions & 16 deletions lib/rag-engines/sagemaker-rag-models/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"intfloat/multilingual-e5-large",
"sentence-transformers/all-MiniLM-L6-v2",
]
cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2"]
cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "rerank-english-v3.0"]


def process_model_list(model_list):
Expand Down Expand Up @@ -130,21 +130,26 @@ def predict_fn(input_object, config):
passages = input_object["passages"]
data = [[current_input, passage] for passage in passages]

with torch.inference_mode():
features = current_tokenizer(
data, padding=True, truncation=True, return_tensors="pt"
)

features = features.to(device)

scores = current_model(**features).logits.cpu().numpy()
ret_value = list(
map(
lambda val: val[-1] if isinstance(val, list) else val,
scores.tolist(),
if current_model_id == "rerank-english-v3.0":
# Use Cohere Rerank 3 API
co = cohere.Client(os.environ["COHERE_API_KEY"])
results = co.rerank(query=current_input, documents=passages, top_n=len(passages), model='rerank-english-v3.0')
ret_value = [result.relevance_score for result in results]
else:
with torch.inference_mode():
features = current_tokenizer(
data, padding=True, truncation=True, return_tensors="pt"
)
)

return ret_value

features = features.to(device)

scores = current_model(**features).logits.cpu().numpy()
ret_value = list(
map(
lambda val: val[-1] if isinstance(val, list) else val,
scores.tolist(),
)
)
return ret_value

return []
10 changes: 10 additions & 0 deletions lib/shared/layers/python-sdk/python/genai_core/clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import boto3
import cohere
import openai
import genai_core.types
import genai_core.parameters
Expand Down Expand Up @@ -52,3 +53,12 @@ def get_bedrock_client(service_name="bedrock-runtime"):
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)

def get_cohere_client():
api_key = genai_core.parameters.get_external_api_key("COHERE_API_KEY")
if not api_key:
return None

cohere_client = cohere.Client(api_key)

return cohere_client
23 changes: 23 additions & 0 deletions lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def rank_passages(

if model.provider == "sagemaker":
return _rank_passages_sagemaker(model, input, passages)
elif model.provider == "cohere":
return _rank_passages_cohere(model, input, passages)

raise genai_core.typesCommonError(f"Unknown provider")

Expand Down Expand Up @@ -66,3 +68,24 @@ def _rank_passages_sagemaker(
ret_value = json.loads(response["Body"].read().decode())

return ret_value

def _rank_passages_cohere(
model: genai_core.types.CrossEncoderModel, input: str, passages: List[str]
):
cohere_client = genai_core.clients.get_cohere_client()
if not cohere_client:
raise genai_core.types.CommonError("Cohere API key not set")

results = cohere_client.rerank(
query=input,
documents=passages,
model=model.name,
)

return [
genai_core.types.RankedPassage(
passage=passage,
score=result.relevance_score,
)
for passage, result in zip(passages, results)
]
2 changes: 1 addition & 1 deletion lib/shared/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as sagemaker from "aws-cdk-lib/aws-sagemaker";

export type ModelProvider = "sagemaker" | "bedrock" | "openai";
export type ModelProvider = "sagemaker" | "bedrock" | "openai" | "cohere";

export enum SupportedSageMakerModels {
FalconLite = "FalconLite [ml.g5.12xlarge]",
Expand Down

0 comments on commit 0042052

Please sign in to comment.