-
Notifications
You must be signed in to change notification settings - Fork 350
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #245 from PrefectHQ/ai-functions-library
- Loading branch information
Showing
8 changed files
with
236 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# AI Functions for entities | ||
|
||
AI functions are powerful tools for extracting structured data from unstructured text. | ||
|
||
## Extract keywords | ||
|
||
```python | ||
from marvin.ai_functions.entities import extract_keywords | ||
|
||
|
||
text = ( | ||
'The United States passed a law that requires all cars to have a "black' | ||
' box" that records data about the car and its driver. The law is sponsored' | ||
" by John Smith. It goes into effect in 2025." | ||
) | ||
|
||
entities_fns.extract_keywords(text) | ||
# ["United States", "law", "cars", "black box", "records", "data", "driver", "John Smith", "2025"] | ||
``` | ||
|
||
## Extract named entitites | ||
|
||
This function extracts named entities, tagging them with spaCy-compatible types: | ||
|
||
```python | ||
from marvin.ai_functions.entities import extract_named_entities | ||
|
||
|
||
text = ( | ||
'The United States passed a law that requires all cars to have a "black' | ||
' box" that records data about the car and its driver. The law is sponsored' | ||
" by John Smith. It goes into effect in 2025." | ||
) | ||
|
||
entities_fns.extract_named_entities(text) | ||
# [ | ||
# NamedEntity(entity="United States", type="GPE"), | ||
# NamedEntity(entity="John Smith", type="PERSON"), | ||
# NamedEntity(entity="2025", type="DATE"), | ||
# ] | ||
``` | ||
|
||
|
||
## Extract any type of entity | ||
|
||
A more flexible extraction function can retrieve multiple entity types in a single pass over the text. Here we pull countries and monetary values out of a sentence: | ||
|
||
```python | ||
from pydantic import BaseModel | ||
class Country(BaseModel): | ||
name: str | ||
|
||
class Money(BaseModel): | ||
amount: float | ||
currency: str | ||
|
||
text = "The United States EV tax credit is $7,500 for cars worth up to $50k." | ||
entities_fns.extract_types(text, types=[Country, Money]) | ||
|
||
# [ | ||
# Country(name="United States"), | ||
# Money(amount=7500, currency="USD"), | ||
# Money(amount=50000, currency="USD"), | ||
# ] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import TypeVar, Union | ||
|
||
from pydantic import Field | ||
|
||
from marvin.ai_functions import ai_fn | ||
from marvin.utilities.types import MarvinBaseModel | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
@ai_fn | ||
def extract_keywords(text: str) -> list[str]: | ||
""" | ||
Extract the most important keywords from the given `text`. Choose words that | ||
best characterize its content. If there are no keywords, return an empty | ||
list. | ||
""" | ||
|
||
|
||
class NamedEntity(MarvinBaseModel): | ||
entity: str = Field(description="The entity name") | ||
type: str = Field(description="The entity type (based on spaCy NER types)") | ||
|
||
|
||
@ai_fn | ||
def extract_named_entities(text: str) -> list[NamedEntity]: | ||
""" | ||
Extract named entities from the given `text` and return a list of | ||
NamedEntity objects. Correct capitalization if needed. | ||
""" | ||
|
||
|
||
def extract_types(text: str, types: list[type[T]]) -> list[T]: | ||
""" | ||
Given text, extract entities of the given `types` in a single pass and | ||
return a list of matched objects. | ||
""" | ||
if len(types) > 1: | ||
types = Union[tuple(types)] | ||
|
||
@ai_fn | ||
def _extract(text: str) -> list[types]: | ||
""" | ||
Extract entities from the given `text` and return a list of any objects | ||
that match any of the provided `types`. Correct capitalization if needed. | ||
""" | ||
|
||
return _extract(text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,38 @@ | ||
import asyncio | ||
import functools | ||
import inspect | ||
|
||
from marvin.ai_functions import ai_fn | ||
|
||
|
||
def _strip_result(fn): | ||
""" | ||
A decorator that automatically strips whitespace from the result of | ||
calling the function | ||
""" | ||
|
||
@functools.wraps(fn) | ||
def wrapper(*args, **kwargs): | ||
result = fn(*args, **kwargs) | ||
if inspect.iscoroutine(result): | ||
result = asyncio.run(result) | ||
return result.strip() | ||
|
||
return wrapper | ||
|
||
|
||
@_strip_result | ||
@ai_fn | ||
def fix_capitalization(text: str) -> str: | ||
""" | ||
Given `text`, which represents complete sentences, fix any capitalization | ||
errors. | ||
""" | ||
|
||
|
||
@_strip_result | ||
@ai_fn | ||
def title_case(input: str) -> str: | ||
def title_case(text: str) -> str: | ||
""" | ||
Given a string {input} change the case to make it APA style guide title case. | ||
Given `text`, change the case to make it APA style guide title case. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from marvin.ai_functions import entities as entities_fns | ||
from marvin.utilities.types import MarvinBaseModel | ||
|
||
|
||
class TestKeywordExtraction: | ||
def test_keyword_extraction(self): | ||
text = ( | ||
'The United States passed a law that requires all cars to have a "black' | ||
' box" that records data about the car and its driver. The law is sponsored' | ||
" by John Smith. It goes into effect in 2025." | ||
) | ||
result = entities_fns.extract_keywords(text) | ||
assert result == [ | ||
"United States", | ||
"law", | ||
"cars", | ||
"black box", | ||
"records", | ||
"data", | ||
"driver", | ||
"John Smith", | ||
"2025", | ||
] | ||
|
||
|
||
class TestNamedEntityExtraction: | ||
def test_named_entity_extraction(self): | ||
text = ( | ||
'The United States passed a law that requires all cars to have a "black' | ||
' box" that records data about the car and its driver. The law is sponsored' | ||
" by John Smith. It goes into effect in 2025." | ||
) | ||
result = entities_fns.extract_named_entities(text) | ||
assert result == [ | ||
entities_fns.NamedEntity(entity="United States", type="GPE"), | ||
entities_fns.NamedEntity(entity="John Smith", type="PERSON"), | ||
entities_fns.NamedEntity(entity="2025", type="DATE"), | ||
] | ||
|
||
|
||
class TestExtractTypes: | ||
class Country(MarvinBaseModel): | ||
name: str | ||
|
||
class Money(MarvinBaseModel): | ||
amount: float | ||
currency: str | ||
|
||
def test_extract_types(self, gpt_4): | ||
text = "The United States EV tax credit is $7,500 for cars worth up to $50k." | ||
result = entities_fns.extract_types(text, types=[self.Country, self.Money]) | ||
|
||
assert result == [ | ||
self.Country(name="United States"), | ||
self.Money(amount=7500, currency="USD"), | ||
self.Money(amount=50000, currency="USD"), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,18 @@ | ||
from marvin.ai_functions import strings as string_fns | ||
|
||
|
||
class TestFixCapitalization: | ||
def test_fix_capitalization(self, gpt_4): | ||
result = string_fns.fix_capitalization("the european went over to canada, eh?") | ||
assert result == "The European went over to Canada, eh?" | ||
|
||
|
||
class TestTitleCase: | ||
def test_title_case(self): | ||
result = string_fns.title_case("the european went over to canada, eh?") | ||
assert result == "The European Went Over to Canada, Eh?" | ||
|
||
def test_short_prepositions_not_capitalized(self): | ||
result = string_fns.title_case( | ||
input="let me go to the store", | ||
) | ||
result = string_fns.title_case("let me go to the store") | ||
|
||
assert result == "Let Me Go to the store" | ||
assert result == "Let Me Go to the Store" |