Skip to content

Commit

Permalink
Put graphviz as optional requirement
Browse files Browse the repository at this point in the history
Enable LNN installation without plotting ability, plotting modules can be installed via extra requirements.
  • Loading branch information
mikulatomas committed Apr 6, 2022
1 parent dfcef59 commit c2f8d1c
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 13 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install graphviz
run: sudo apt-get update && sudo apt-get install -y graphviz graphviz-dev
- name: Upgrade pip
run: python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
Expand Down
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ properties of both neural nets (learning) and symbolic logic (knowledge and reas
## Quickstart

To install the LNN:
1. Run:
```
pip install git+https://github.com/IBM/LNN.git
```

To install the LNN with graph plot support:
1. Install [GraphViz](https://www.graphviz.org/download/)
2. Run:
```
pip install git+https://github.com/IBM/LNN.git
pip install git+https://github.com/IBM/LNN.git#egg=lnn"[plot]"
```

## Documentation
Expand Down
3 changes: 2 additions & 1 deletion lnn/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Union, TypeVar, Tuple

import torch
import torchviz
import numpy as np

from . import _exceptions
Expand Down Expand Up @@ -177,6 +176,8 @@ def dict_rekey(d, old_key, new_key) -> None:


def plot_autograd(model: Model, loss: torch.Tensor, **kwds) -> None:
import torchviz

params = model.named_parameters()
torchviz.make_dot(
loss,
Expand Down
7 changes: 6 additions & 1 deletion lnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from . import _utils
from .constants import Fact
Expand Down Expand Up @@ -78,6 +77,8 @@ def predicate_truth_table(*args: str, arity: int, model, states=None):


def plot_graph(self, **kwds) -> None:
import matplotlib.pyplot as plt

labels = {node: f"{node.__class__.__name__}\n{node}" for node in self.graph}

options = {
Expand All @@ -97,6 +98,8 @@ def plot_graph(self, **kwds) -> None:


def plot_loss(total_loss, losses) -> None:
import matplotlib.pyplot as plt

loss, cummulative_loss = total_loss
fig, axs = plt.subplots(1, 2)
fig.suptitle("Model Loss")
Expand All @@ -113,6 +116,8 @@ def plot_loss(total_loss, losses) -> None:


def plot_params(self: Model) -> None:
import matplotlib.pyplot as plt

legend = []
for node in self.nodes:
if hasattr(self[node], "parameter_history"):
Expand Down
8 changes: 2 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
networkx==2.5.1
torch==1.9.0
matplotlib==3.3.3
pygraphviz==1.6
torch==1.11.0
tqdm==4.50.2
numpy~=1.21.0
setuptools~=52.0.0
torchviz==0.0.2
numpy==1.21.0
3 changes: 3 additions & 0 deletions requirements_plot.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pygraphviz==1.6
torchviz==0.0.2
matplotlib==3.3.3
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==7.1.0
9 changes: 7 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import pathlib


def parse_requirements(filename):
return pathlib.Path(filename).read_text().replace("==", ">=").split("\n")


setuptools.setup(
name="lnn",
version="1.0",
Expand All @@ -17,9 +21,10 @@
long_description_content_type="text/markdown",
url="https://github.com/IBM/LNN",
packages=setuptools.find_packages(),
install_requires=pathlib.Path("requirements.txt").read_text().replace("==", ">="),
install_requires=parse_requirements("requirements.txt"),
extras_require={
"test": ["pytest"],
"test": parse_requirements("requirements_test.txt"),
"plot": parse_requirements("requirements_plot.txt"),
},
classifiers=[
"Programming Language :: Python :: 3",
Expand Down

0 comments on commit c2f8d1c

Please sign in to comment.