Skip to content

Commit

Permalink
Refactor elevenlabs tool
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Sep 13, 2023
1 parent 97122fb commit 79a567d
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions libs/langchain/langchain/tools/eleven_labs/text2speech.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import tempfile
from typing import TYPE_CHECKING, Dict, Optional, Union
from enum import Enum
from typing import Any, Dict, Optional, Union

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import root_validator
from langchain.tools.base import BaseTool
from langchain.tools.eleven_labs.models import ElevenLabsModel
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:

def _import_elevenlabs() -> Any:
try:
import elevenlabs

except ImportError:
except ImportError as e:
raise ImportError(
"elevenlabs is not installed. " "Run `pip install elevenlabs` to install."
)
"Cannot import elevenlabs, please install `pip install elevenlabs`."
) from e
return elevenlabs


class ElevenLabsModel(str, Enum):
"""Models available for Eleven Labs Text2Speech."""

MULTI_LINGUAL = "eleven_multilingual_v1"
MONO_LINGUAL = "eleven_monolingual_v1"


class ElevenLabsText2SpeechTool(BaseTool):
Expand All @@ -41,24 +49,24 @@ def validate_environment(cls, values: Dict) -> Dict:

return values

def _text2speech(self, text: str) -> str:
speech = elevenlabs.generate(text=text, model=self.model)
with tempfile.NamedTemporaryFile(mode="bx", suffix=".wav", delete=False) as f:
f.write(speech)
return f.name

def _run(
self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Use the tool."""
elevenlabs = _import_elevenlabs()
try:
speech_file = self._text2speech(query)
return speech_file
speech = elevenlabs.generate(text=query, model=self.model)
with tempfile.NamedTemporaryFile(
mode="bx", suffix=".wav", delete=False
) as f:
f.write(speech)
return f.name
except Exception as e:
raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")

def play(self, speech_file: str) -> None:
"""Play the text as speech."""
elevenlabs = _import_elevenlabs()
with open(speech_file, mode="rb") as f:
speech = f.read()

Expand All @@ -67,5 +75,6 @@ def play(self, speech_file: str) -> None:
def stream_speech(self, query: str) -> None:
"""Stream the text as speech as it is generated.
Play the text in your speakers."""
elevenlabs = _import_elevenlabs()
speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
elevenlabs.stream(speech_stream)

0 comments on commit 79a567d

Please sign in to comment.