Skip to content

Commit

Permalink
Merge pull request IBM#22 from mikulatomas/graphviz_optional
Browse files Browse the repository at this point in the history
Put graphviz as optional requirement
  • Loading branch information
NaweedAghmad authored Apr 6, 2022
2 parents dfcef59 + c2f8d1c commit 780a1ff
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 13 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install graphviz
run: sudo apt-get update && sudo apt-get install -y graphviz graphviz-dev
- name: Upgrade pip
run: python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
Expand Down
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ properties of both neural nets (learning) and symbolic logic (knowledge and reas
## Quickstart

To install the LNN:
1. Run:
```
pip install git+https://github.com/IBM/LNN.git
```

To install the LNN with graph plot support:
1. Install [GraphViz](https://www.graphviz.org/download/)
2. Run:
```
pip install git+https://github.com/IBM/LNN.git
pip install git+https://github.com/IBM/LNN.git#egg=lnn"[plot]"
```

## Documentation
Expand Down
3 changes: 2 additions & 1 deletion lnn/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Union, TypeVar, Tuple

import torch
import torchviz
import numpy as np

from . import _exceptions
Expand Down Expand Up @@ -177,6 +176,8 @@ def dict_rekey(d, old_key, new_key) -> None:


def plot_autograd(model: Model, loss: torch.Tensor, **kwds) -> None:
import torchviz

params = model.named_parameters()
torchviz.make_dot(
loss,
Expand Down
7 changes: 6 additions & 1 deletion lnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from . import _utils
from .constants import Fact
Expand Down Expand Up @@ -78,6 +77,8 @@ def predicate_truth_table(*args: str, arity: int, model, states=None):


def plot_graph(self, **kwds) -> None:
import matplotlib.pyplot as plt

labels = {node: f"{node.__class__.__name__}\n{node}" for node in self.graph}

options = {
Expand All @@ -97,6 +98,8 @@ def plot_graph(self, **kwds) -> None:


def plot_loss(total_loss, losses) -> None:
import matplotlib.pyplot as plt

loss, cummulative_loss = total_loss
fig, axs = plt.subplots(1, 2)
fig.suptitle("Model Loss")
Expand All @@ -113,6 +116,8 @@ def plot_loss(total_loss, losses) -> None:


def plot_params(self: Model) -> None:
import matplotlib.pyplot as plt

legend = []
for node in self.nodes:
if hasattr(self[node], "parameter_history"):
Expand Down
8 changes: 2 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
networkx==2.5.1
torch==1.9.0
matplotlib==3.3.3
pygraphviz==1.6
torch==1.11.0
tqdm==4.50.2
numpy~=1.21.0
setuptools~=52.0.0
torchviz==0.0.2
numpy==1.21.0
3 changes: 3 additions & 0 deletions requirements_plot.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pygraphviz==1.6
torchviz==0.0.2
matplotlib==3.3.3
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==7.1.0
9 changes: 7 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import pathlib


def parse_requirements(filename):
return pathlib.Path(filename).read_text().replace("==", ">=").split("\n")


setuptools.setup(
name="lnn",
version="1.0",
Expand All @@ -17,9 +21,10 @@
long_description_content_type="text/markdown",
url="https://github.com/IBM/LNN",
packages=setuptools.find_packages(),
install_requires=pathlib.Path("requirements.txt").read_text().replace("==", ">="),
install_requires=parse_requirements("requirements.txt"),
extras_require={
"test": ["pytest"],
"test": parse_requirements("requirements_test.txt"),
"plot": parse_requirements("requirements_plot.txt"),
},
classifiers=[
"Programming Language :: Python :: 3",
Expand Down

0 comments on commit 780a1ff

Please sign in to comment.