Skip to content

Commit

Permalink
Merge pull request #245 from PrefectHQ/ai-functions-library
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin authored Apr 22, 2023
2 parents bc31086 + e7bcd17 commit 3190e84
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 12 deletions.
65 changes: 65 additions & 0 deletions docs/guide/ai_functions/entities.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# AI Functions for entities

AI functions are powerful tools for extracting structured data from unstructured text.

## Extract keywords

```python
from marvin.ai_functions.entities import extract_keywords


text = (
'The United States passed a law that requires all cars to have a "black'
' box" that records data about the car and its driver. The law is sponsored'
" by John Smith. It goes into effect in 2025."
)

entities_fns.extract_keywords(text)
# ["United States", "law", "cars", "black box", "records", "data", "driver", "John Smith", "2025"]
```

## Extract named entitites

This function extracts named entities, tagging them with spaCy-compatible types:

```python
from marvin.ai_functions.entities import extract_named_entities


text = (
'The United States passed a law that requires all cars to have a "black'
' box" that records data about the car and its driver. The law is sponsored'
" by John Smith. It goes into effect in 2025."
)

entities_fns.extract_named_entities(text)
# [
# NamedEntity(entity="United States", type="GPE"),
# NamedEntity(entity="John Smith", type="PERSON"),
# NamedEntity(entity="2025", type="DATE"),
# ]
```


## Extract any type of entity

A more flexible extraction function can retrieve multiple entity types in a single pass over the text. Here we pull countries and monetary values out of a sentence:

```python
from pydantic import BaseModel
class Country(BaseModel):
name: str

class Money(BaseModel):
amount: float
currency: str

text = "The United States EV tax credit is $7,500 for cars worth up to $50k."
entities_fns.extract_types(text, types=[Country, Money])

# [
# Country(name="United States"),
# Money(amount=7500, currency="USD"),
# Money(amount=50000, currency="USD"),
# ]
```
12 changes: 11 additions & 1 deletion docs/guide/ai_functions/strings.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# AI Functions for strings

## Actual title case
## Fix capitalization
Given a string that may not have correct capitalization, fix its capitalization but make no other changes.

```python
from marvin.ai_functions.strings import fix_capitalization

fix_capitalization("the european went over to canada, eh?")
# The European went over to Canada, eh?
```

## APA title case

Return a title case string that you would want to use in a title.

Expand Down
11 changes: 6 additions & 5 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,25 @@ nav:
- AI functions: getting_started/ai_functions_quickstart.md
- Bots: getting_started/bots_quickstart.md
- The Guide:
- Introduction:
- Introduction 👋:
- Overview: guide/introduction/overview.md
- Configuration: guide/introduction/configuration.md
- CLI: guide/introduction/cli.md
- Concepts:
- Concepts 💡:
- AI functions: guide/concepts/ai_functions.md
- Bots: guide/concepts/bots.md
- TUI: guide/concepts/tui.md
- Loaders: guide/concepts/loaders_and_documents.md
- Infrastructure: guide/concepts/infra.md
- Plugins: guide/concepts/plugins.md
- AI Functions:
- AI Functions 🪄:
- Entity extraction: guide/ai_functions/entities.md
- Data: guide/ai_functions/data.md
- Strings: guide/ai_functions/strings.md
- Use Cases:
- Use Cases 🏗️:
- Enforcing LLM output formats: guide/use_cases/enforcing_format.md
- Slackbot: guide/use_cases/slackbot.md
- Development:
- Development 🧑‍💻:
- Setting up: development/development.md
- FAQ: faq.md

Expand Down
5 changes: 5 additions & 0 deletions src/marvin/ai_functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, TypeVar

from marvin.bot import Bot
from marvin.bot.history import InMemoryHistory
from marvin.utilities.strings import jinja_env

AI_FN_INSTRUCTIONS = jinja_env.from_string(inspect.cleandoc("""
Expand Down Expand Up @@ -146,6 +147,10 @@ def ai_fn_wrapper(*args, **kwargs) -> Any:
if "plugins" not in wrapper_bot_kwargs:
wrapper_bot_kwargs["plugins"] = []

# ai functions do not persist by default
if "history" not in wrapper_bot_kwargs:
wrapper_bot_kwargs["history"] = InMemoryHistory()

# create the bot
bot = Bot(
instructions=instructions,
Expand Down
48 changes: 48 additions & 0 deletions src/marvin/ai_functions/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import TypeVar, Union

from pydantic import Field

from marvin.ai_functions import ai_fn
from marvin.utilities.types import MarvinBaseModel

T = TypeVar("T")


@ai_fn
def extract_keywords(text: str) -> list[str]:
"""
Extract the most important keywords from the given `text`. Choose words that
best characterize its content. If there are no keywords, return an empty
list.
"""


class NamedEntity(MarvinBaseModel):
entity: str = Field(description="The entity name")
type: str = Field(description="The entity type (based on spaCy NER types)")


@ai_fn
def extract_named_entities(text: str) -> list[NamedEntity]:
"""
Extract named entities from the given `text` and return a list of
NamedEntity objects. Correct capitalization if needed.
"""


def extract_types(text: str, types: list[type[T]]) -> list[T]:
"""
Given text, extract entities of the given `types` in a single pass and
return a list of matched objects.
"""
if len(types) > 1:
types = Union[tuple(types)]

@ai_fn
def _extract(text: str) -> list[types]:
"""
Extract entities from the given `text` and return a list of any objects
that match any of the provided `types`. Correct capitalization if needed.
"""

return _extract(text)
34 changes: 32 additions & 2 deletions src/marvin/ai_functions/strings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,38 @@
import asyncio
import functools
import inspect

from marvin.ai_functions import ai_fn


def _strip_result(fn):
"""
A decorator that automatically strips whitespace from the result of
calling the function
"""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)
if inspect.iscoroutine(result):
result = asyncio.run(result)
return result.strip()

return wrapper


@_strip_result
@ai_fn
def fix_capitalization(text: str) -> str:
"""
Given `text`, which represents complete sentences, fix any capitalization
errors.
"""


@_strip_result
@ai_fn
def title_case(input: str) -> str:
def title_case(text: str) -> str:
"""
Given a string {input} change the case to make it APA style guide title case.
Given `text`, change the case to make it APA style guide title case.
"""
57 changes: 57 additions & 0 deletions tests/llm_tests/ai_functions/test_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from marvin.ai_functions import entities as entities_fns
from marvin.utilities.types import MarvinBaseModel


class TestKeywordExtraction:
def test_keyword_extraction(self):
text = (
'The United States passed a law that requires all cars to have a "black'
' box" that records data about the car and its driver. The law is sponsored'
" by John Smith. It goes into effect in 2025."
)
result = entities_fns.extract_keywords(text)
assert result == [
"United States",
"law",
"cars",
"black box",
"records",
"data",
"driver",
"John Smith",
"2025",
]


class TestNamedEntityExtraction:
def test_named_entity_extraction(self):
text = (
'The United States passed a law that requires all cars to have a "black'
' box" that records data about the car and its driver. The law is sponsored'
" by John Smith. It goes into effect in 2025."
)
result = entities_fns.extract_named_entities(text)
assert result == [
entities_fns.NamedEntity(entity="United States", type="GPE"),
entities_fns.NamedEntity(entity="John Smith", type="PERSON"),
entities_fns.NamedEntity(entity="2025", type="DATE"),
]


class TestExtractTypes:
class Country(MarvinBaseModel):
name: str

class Money(MarvinBaseModel):
amount: float
currency: str

def test_extract_types(self, gpt_4):
text = "The United States EV tax credit is $7,500 for cars worth up to $50k."
result = entities_fns.extract_types(text, types=[self.Country, self.Money])

assert result == [
self.Country(name="United States"),
self.Money(amount=7500, currency="USD"),
self.Money(amount=50000, currency="USD"),
]
16 changes: 12 additions & 4 deletions tests/llm_tests/ai_functions/test_strings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from marvin.ai_functions import strings as string_fns


class TestFixCapitalization:
def test_fix_capitalization(self, gpt_4):
result = string_fns.fix_capitalization("the european went over to canada, eh?")
assert result == "The European went over to Canada, eh?"


class TestTitleCase:
def test_title_case(self):
result = string_fns.title_case("the european went over to canada, eh?")
assert result == "The European Went Over to Canada, Eh?"

def test_short_prepositions_not_capitalized(self):
result = string_fns.title_case(
input="let me go to the store",
)
result = string_fns.title_case("let me go to the store")

assert result == "Let Me Go to the store"
assert result == "Let Me Go to the Store"

0 comments on commit 3190e84

Please sign in to comment.