Merge pull request #26 from jonfairbanks/develop
Allow embeddings to use cuda devices
jonfairbanks authored Mar 7, 2024
2 parents e8aa361 + 24d9a79 commit cbbe3e2
Showing 15 changed files with 231 additions and 213 deletions.
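The change at the heart of this PR is the switch from `numba.cuda` to `torch.cuda` for detecting a GPU before constructing the HuggingFace embedding model. A minimal sketch of that device-selection pattern follows; the model name is an illustrative placeholder, not taken from this commit:

```python
# Sketch of the CUDA-aware embedding setup introduced here.
# Assumes llama-index-embeddings-huggingface and torch are installed;
# the default model name below is a placeholder, not the one used by Local RAG.
from torch import cuda
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


def build_embed_model(model_name: str = "BAAI/bge-large-en-v1.5") -> HuggingFaceEmbedding:
    # Run embeddings on the GPU when torch can see one; otherwise fall back to CPU.
    device = "cuda" if cuda.is_available() else "cpu"
    return HuggingFaceEmbedding(model_name=model_name, device=device)
```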
3 changes: 2 additions & 1 deletion Pipfile
@@ -12,11 +12,12 @@ llama-index-embeddings-huggingface = "*"
pycryptodome = "*"
nbconvert = "*"
pyexiftool = "*"
numba = "*"
llama-index-readers-web = "*"
html2text = "*"
streamlit-tags = "*"
streamlit-extras = "*"
black = "*"
torch = "*"

[dev-packages]

295 changes: 139 additions & 156 deletions Pipfile.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion README.md
@@ -2,6 +2,7 @@

![local-rag-demo](demo.gif)

[![OpenSSF Best Practices](https://www.bestpractices.dev/projects/8588/badge)](https://www.bestpractices.dev/projects/8588)
![GitHub commit activity](https://img.shields.io/github/commit-activity/t/jonfairbanks/local-rag)
![GitHub last commit](https://img.shields.io/github/last-commit/jonfairbanks/local-rag)
![GitHub License](https://img.shields.io/github/license/jonfairbanks/local-rag)
@@ -13,8 +14,12 @@ Ingest files for retrieval augmented generation (RAG) with open-source Large Lan
Features:

- Offline Embeddings & LLMs Support (No OpenAI!)
- Support for Multiple Sources
- Local Files
- GitHub Repos
- Websites
- Streaming Responses
- Conversation Memory
- Conversational Memory
- Chat Export

Learn More:
13 changes: 7 additions & 6 deletions components/page_config.py
@@ -6,20 +6,21 @@ def set_page_config():
page_title="Local RAG",
page_icon="📚",
layout="wide",
initial_sidebar_state=st.session_state['sidebar_state'],
initial_sidebar_state=st.session_state["sidebar_state"],
menu_items={
'Get Help': 'https://github.com/jonfairbanks/local-rag/discussions',
'Report a bug': "https://github.com/jonfairbanks/local-rag/issues",
}
"Get Help": "https://github.com/jonfairbanks/local-rag/discussions",
"Report a bug": "https://github.com/jonfairbanks/local-rag/issues",
},
)

# Remove the Streamlit `Deploy` button from the Header
st.markdown(
r"""
r"""
<style>
.stDeployButton {
visibility: hidden;
}
</style>
""", unsafe_allow_html=True
""",
unsafe_allow_html=True,
)
14 changes: 9 additions & 5 deletions components/page_state.py
@@ -11,8 +11,8 @@ def set_initial_state():
# General #
###########

if 'sidebar_state' not in st.session_state:
st.session_state['sidebar_state'] = 'expanded'
if "sidebar_state" not in st.session_state:
st.session_state["sidebar_state"] = "expanded"

if "ollama_endpoint" not in st.session_state:
st.session_state["ollama_endpoint"] = "http://localhost:11434"
@@ -30,10 +30,14 @@ def set_initial_state():

if "selected_model" not in st.session_state:
try:
if("llama2:7b" in st.session_state["ollama_models"]):
st.session_state["selected_model"] = "llama2:7b" # Default to llama2:7b on initial load
if "llama2:7b" in st.session_state["ollama_models"]:
st.session_state["selected_model"] = (
"llama2:7b" # Default to llama2:7b on initial load
)
else:
st.session_state["selected_model"] = st.session_state["ollama_models"][0] # If llama2:7b is not present, select the first model available
st.session_state["selected_model"] = st.session_state["ollama_models"][
0
] # If llama2:7b is not present, select the first model available
except Exception:
st.session_state["selected_model"] = None
pass
1 change: 1 addition & 0 deletions components/tabs/local_files.py
@@ -22,6 +22,7 @@
"txt",
)


def local_files():
# Force users to confirm Settings before uploading files
if st.session_state["selected_model"] is not None:
24 changes: 12 additions & 12 deletions components/tabs/website.py
@@ -5,6 +5,7 @@
from llama_index.readers.web import SimpleWebPageReader
from urllib.parse import urlparse


def ensure_https(url):
parsed = urlparse(url)

@@ -13,6 +14,7 @@ def ensure_https(url):

return url


def website():
# if st.session_state["selected_model"] is not None:
# st.text_input(
@@ -44,31 +46,29 @@ def website():

# st.write(css_example, unsafe_allow_html=True)



st.write("Enter a Website")
col1, col2 = st.columns([1,.2])
col1, col2 = st.columns([1, 0.2])
with col1:
new_website = st.text_input("Enter a Website", label_visibility="collapsed")
with col2:
add_button = st.button(u"➕")
add_button = st.button("➕")

# If the add button is clicked, append the new website to our list
if add_button and new_website != '':
st.session_state['websites'].append(ensure_https(new_website))
st.session_state['websites'] = sorted(set(st.session_state['websites']))
if st.session_state['websites'] != []:
if add_button and new_website != "":
st.session_state["websites"].append(ensure_https(new_website))
st.session_state["websites"] = sorted(set(st.session_state["websites"]))

if st.session_state["websites"] != []:
st.markdown(f"<p>Website(s)</p>", unsafe_allow_html=True)
for site in st.session_state['websites']:
for site in st.session_state["websites"]:
st.caption(f"- {site}")
st.write("")

process_button = st.button("Process", key="process_website")

if process_button:
documents = SimpleWebPageReader(html_to_text=True).load_data(
st.session_state['websites']
st.session_state["websites"]
)

if len(documents) > 0:
8 changes: 8 additions & 0 deletions docs/setup.md
@@ -13,3 +13,11 @@ Before you get started with Local RAG, ensure you have:

### Docker
- `docker compose up -d`

#### Note:

If you are running Ollama as a service, you may need to add an additional configuration to your docker-compose.yml file:
```
extra_hosts:
- 'host.docker.internal:host-gateway'
```
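
For readers following the docs change, the `extra_hosts` entry sits under the service definition in docker-compose.yml. A hypothetical excerpt is shown below; the service name, build context, and port mapping are placeholders, and only the `extra_hosts` block mirrors the note above:

```yaml
# Hypothetical docker-compose.yml excerpt; only extra_hosts comes from the docs.
services:
  local-rag:
    build: .
    ports:
      - "8501:8501"
    extra_hosts:
      - "host.docker.internal:host-gateway"
```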
5 changes: 2 additions & 3 deletions docs/todo.md
@@ -19,7 +19,7 @@ Although not final, items are generally sorted from highest to lowest priority.
- [x] Docker Support
- [x] Windows Support
- [ ] Extract Metadata and Load into Index
- [ ] Parallelize Document Embeddings
- [x] Faster Document Embeddings (Cuda, Batch Size, ...)
- [ ] Swap to OpenAI compatible endpoints
- [ ] Allow Usage of Ollama hosted embeddings
- [ ] Enable support for additional LLM backends
@@ -33,7 +33,6 @@ Although not final, items are generally sorted from highest to lowest priority.
- [x] View and Manage Imported Files
- [x] About Tab in Sidebar w/ Resources
- [x] Enable Caching
- [ ] Swap Repo & Website input to [Streamlit-Tags](https://gagan3012-streamlit-tags-examplesapp-7aiy65.streamlit.app)
- [ ] Allow Users to Set LLM Settings
- [ ] System Prompt (needs more work)
- [x] Chat Mode
@@ -55,7 +54,7 @@ Although not final, items are generally sorted from highest to lowest priority.
- [x] Refactor README
- [x] Implement Log Library
- [x] Improve Logging
- [ ] Re-write Docstrings
- [x] Re-write Docstrings
- [ ] Tests

### Known Issues & Bugs
1 change: 1 addition & 0 deletions main.py
@@ -15,6 +15,7 @@ def generate_welcome_message(msg):
time.sleep(0.025) # This is blocking :(
yield char


### Setup Initial State
set_initial_state()

4 changes: 2 additions & 2 deletions utils/helpers.py
@@ -111,7 +111,7 @@ def clone_github_repo(repo: str):
# Extract File Metadata
#
###################################


def get_file_metadata(file_path):
"""
@@ -131,4 +131,4 @@ def get_file_metadata(file_path):
for d in et.get_metadata(file_path):
return json.dumps(d, indent=2)
except Exception:
pass
pass
26 changes: 16 additions & 10 deletions utils/llama_index.py
@@ -4,10 +4,10 @@

import utils.logs as logs

from numba import cuda
from llama_index.embeddings.huggingface import ( HuggingFaceEmbedding )
from torch import cuda
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# This is not used but required by llama-index and must be imported FIRST
# This is not used but required by llama-index and must be set FIRST
os.environ["OPENAI_API_KEY"] = "sk-abc123"

from llama_index.core import (
@@ -24,6 +24,7 @@
#
###################################


@st.cache_resource(show_spinner=False)
def setup_embedding_model(
model: str,
@@ -43,11 +44,12 @@ def setup_embedding_model(
Notes:
The `device` parameter can be set to 'cpu' or 'cuda' to specify the device to use for the embedding computations. If 'cuda' is used and CUDA is available, the embedding model will be run on the GPU. Otherwise, it will be run on the CPU.
"""
device = 'cpu' if not cuda.is_available() else 'cuda'
device = "cpu" if not cuda.is_available() else "cuda"
logs.log.info(f"Using {device} to generate embeddings")
embed_model = HuggingFaceEmbedding(
model_name=model,
# embed_batch_size=25, // TODO: Turning this on creates chaos, but has the potential to improve performance
device=device
device=device,
)
logs.log.info(f"Embedding model created successfully")
return embed_model
@@ -61,6 +63,7 @@ def setup_embedding_model(

# TODO: Migrate to LlamaIndex.Settings: https://docs.llamaindex.ai/en/stable/module_guides/supporting_modules/service_context_migration.html


def create_service_context(
llm, # TODO: Determine type
system_prompt: str = None,
@@ -95,13 +98,13 @@ def create_service_context(
service_context = ServiceContext.from_defaults(
llm=llm,
system_prompt=system_prompt,
embed_model=formatted_embed_model,
embed_model=embedding_model,
chunk_size=int(chunk_size),
# chunk_overlap=int(chunk_overlap),
)
logs.log.info(f"Service Context created successfully")
st.session_state["service_context"] = service_context

# Note: this may be redundant since service_context is returned
set_global_service_context(service_context)

@@ -143,7 +146,9 @@ def load_documents(data_dir: str):
logs.log.error(f"Error creating data index: {err}")
finally:
for file in os.scandir(data_dir):
if file.is_file() and not file.name.startswith(".gitkeep"): # TODO: Confirm syntax here
if file.is_file() and not file.name.startswith(
".gitkeep"
): # TODO: Confirm syntax here
os.remove(file.path)
logs.log.info(f"Document loading complete; removing local file(s)")

@@ -154,6 +159,7 @@ def load_documents(data_dir: str):
#
###################################


@st.cache_data(show_spinner=False)
def create_index(_documents, _service_context):
"""
@@ -172,7 +178,7 @@ def create_index(_documents, _service_context):
Notes:
The `documents` parameter should be a list of strings representing the content of the documents to be indexed. The `service_context` parameter should be an instance of `ServiceContext`, providing information about the Llama model and other configuration settings for the index.
"""

try:
index = VectorStoreIndex.from_documents(
documents=_documents, service_context=_service_context, show_progress=True
Expand All @@ -193,7 +199,7 @@ def create_index(_documents, _service_context):
###################################


@st.cache_data(show_spinner=False)
@st.cache_resource(show_spinner=False)
def create_query_engine(_documents, _service_context):
"""
Creates a query engine from the provided documents and service context.
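
One detail worth noting in this file: `create_query_engine` moves from `@st.cache_data` to `@st.cache_resource`, matching the decorator already used on `setup_embedding_model`. Streamlit's `cache_resource` keeps a single shared, unserialized instance per argument set (suited to models, indexes, and query engines), while `cache_data` serializes and copies its return value on each hit. A small illustrative sketch, with placeholder function bodies that are not from this repo:

```python
# Illustrative only: shows the split between the two Streamlit cache decorators.
import streamlit as st


@st.cache_resource(show_spinner=False)
def get_query_engine(collection: str):
    # Heavyweight, non-picklable object: built once and shared across reruns.
    # Placeholder body; the real app builds a VectorStoreIndex query engine here.
    return object()


@st.cache_data(show_spinner=False)
def load_chat_export(path: str) -> str:
    # Picklable return value: cached per input and copied back on each hit.
    with open(path) as f:
        return f.read()
```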
4 changes: 3 additions & 1 deletion utils/logs.py
@@ -4,7 +4,9 @@
from typing import Union


def setup_logger(log_file: str = "local-rag.log", level: Union[int, str] = logging.INFO):
def setup_logger(
log_file: str = "local-rag.log", level: Union[int, str] = logging.INFO
):
"""
Sets up a logger for this module.
22 changes: 12 additions & 10 deletions utils/ollama.py
@@ -61,8 +61,8 @@ def get_models():
- Exception: If there is an error retrieving the list of models.
Notes:
This function retrieves a list of available language models from the Ollama server using the `ollama` library. It takes no parameters and returns a list of available language model names.
This function retrieves a list of available language models from the Ollama server using the `ollama` library. It takes no parameters and returns a list of available language model names.
The function raises an exception if there is an error retrieving the list of models.
Side Effects:
@@ -80,8 +80,10 @@ def get_models():
if len(models) > 0:
logs.log.info("Ollama models loaded successfully")
else:
logs.log.warn("Ollama did not return any models. Make sure to download some!")

logs.log.warn(
"Ollama did not return any models. Make sure to download some!"
)

return models
except Exception as err:
logs.log.error(f"Failed to retrieve Ollama model list: {err}")
@@ -167,12 +169,12 @@ def context_chat(prompt: str, query_engine: RetrieverQueryEngine):
- Exception: If there is an error retrieving answers from the Llama-Index model.
Notes:
This function initiates a chat with context using the Llama-Index language model and index.
It takes two parameters, `prompt` and `query_engine`, which should be the starting prompt for the conversation and the Llama-Index query engine to use for retrieving answers, respectively.
The function returns an iterable yielding successive chunks of conversation from the Llama-Index index with context.
This function initiates a chat with context using the Llama-Index language model and index.
It takes two parameters, `prompt` and `query_engine`, which should be the starting prompt for the conversation and the Llama-Index query engine to use for retrieving answers, respectively.
The function returns an iterable yielding successive chunks of conversation from the Llama-Index index with context.
If there is an error retrieving answers from the Llama-Index instance, the function raises an exception.
Side Effects:
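
Based on the docstring above, `context_chat` yields successive text chunks; a hypothetical caller could stream them into the chat UI as sketched below. The `query_engine` session key and the use of `st.write_stream` are assumptions, not code from this commit:

```python
# Hypothetical caller for context_chat; assumes a RetrieverQueryEngine has
# already been stored in st.session_state["query_engine"] by the app.
import streamlit as st

from utils.ollama import context_chat

prompt = st.chat_input("Ask a question about your documents")
if prompt:
    with st.chat_message("assistant"):
        # st.write_stream consumes the generator and renders chunks as they arrive.
        st.write_stream(context_chat(prompt, st.session_state["query_engine"]))
```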