From c2f8d1c4be711743cf9992480769ed11589fef81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Mikula?= Date: Wed, 6 Apr 2022 16:28:19 +0200 Subject: [PATCH] Put graphviz as optional requirement Enable LNN installation without plotting ability, plotting modules can be installed via extra requirements. --- .github/workflows/build.yml | 2 -- README.md | 8 +++++++- lnn/_utils.py | 3 ++- lnn/utils.py | 7 ++++++- requirements.txt | 8 ++------ requirements_plot.txt | 3 +++ requirements_test.txt | 1 + setup.py | 9 +++++++-- 8 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 requirements_plot.txt create mode 100644 requirements_test.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0ad7b6..8a6e194 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,8 +22,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install graphviz - run: sudo apt-get update && sudo apt-get install -y graphviz graphviz-dev - name: Upgrade pip run: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies diff --git a/README.md b/README.md index 3b204d6..da371ea 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,16 @@ properties of both neural nets (learning) and symbolic logic (knowledge and reas ## Quickstart To install the LNN: +1. Run: + ``` + pip install git+https://github.com/IBM/LNN.git + ``` + +To install the LNN with graph plot support: 1. Install [GraphViz](https://www.graphviz.org/download/) 2. Run: ``` - pip install git+https://github.com/IBM/LNN.git + pip install git+https://github.com/IBM/LNN.git#egg=lnn"[plot]" ``` ## Documentation diff --git a/lnn/_utils.py b/lnn/_utils.py index b48a58a..6b59760 100644 --- a/lnn/_utils.py +++ b/lnn/_utils.py @@ -8,7 +8,6 @@ from typing import Union, TypeVar, Tuple import torch -import torchviz import numpy as np from . import _exceptions @@ -177,6 +176,8 @@ def dict_rekey(d, old_key, new_key) -> None: def plot_autograd(model: Model, loss: torch.Tensor, **kwds) -> None: + import torchviz + params = model.named_parameters() torchviz.make_dot( loss, diff --git a/lnn/utils.py b/lnn/utils.py index 37bf69b..a84d956 100644 --- a/lnn/utils.py +++ b/lnn/utils.py @@ -10,7 +10,6 @@ import numpy as np import networkx as nx -import matplotlib.pyplot as plt from . import _utils from .constants import Fact @@ -78,6 +77,8 @@ def predicate_truth_table(*args: str, arity: int, model, states=None): def plot_graph(self, **kwds) -> None: + import matplotlib.pyplot as plt + labels = {node: f"{node.__class__.__name__}\n{node}" for node in self.graph} options = { @@ -97,6 +98,8 @@ def plot_graph(self, **kwds) -> None: def plot_loss(total_loss, losses) -> None: + import matplotlib.pyplot as plt + loss, cummulative_loss = total_loss fig, axs = plt.subplots(1, 2) fig.suptitle("Model Loss") @@ -113,6 +116,8 @@ def plot_loss(total_loss, losses) -> None: def plot_params(self: Model) -> None: + import matplotlib.pyplot as plt + legend = [] for node in self.nodes: if hasattr(self[node], "parameter_history"): diff --git a/requirements.txt b/requirements.txt index 25c8d24..7910737 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,4 @@ networkx==2.5.1 -torch==1.9.0 -matplotlib==3.3.3 -pygraphviz==1.6 +torch==1.11.0 tqdm==4.50.2 -numpy~=1.21.0 -setuptools~=52.0.0 -torchviz==0.0.2 \ No newline at end of file +numpy==1.21.0 \ No newline at end of file diff --git a/requirements_plot.txt b/requirements_plot.txt new file mode 100644 index 0000000..f11fb41 --- /dev/null +++ b/requirements_plot.txt @@ -0,0 +1,3 @@ +pygraphviz==1.6 +torchviz==0.0.2 +matplotlib==3.3.3 \ No newline at end of file diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000..90f614f --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1 @@ +pytest==7.1.0 \ No newline at end of file diff --git a/setup.py b/setup.py index b83795b..7f824bc 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,10 @@ import pathlib +def parse_requirements(filename): + return pathlib.Path(filename).read_text().replace("==", ">=").split("\n") + + setuptools.setup( name="lnn", version="1.0", @@ -17,9 +21,10 @@ long_description_content_type="text/markdown", url="https://github.com/IBM/LNN", packages=setuptools.find_packages(), - install_requires=pathlib.Path("requirements.txt").read_text().replace("==", ">="), + install_requires=parse_requirements("requirements.txt"), extras_require={ - "test": ["pytest"], + "test": parse_requirements("requirements_test.txt"), + "plot": parse_requirements("requirements_plot.txt"), }, classifiers=[ "Programming Language :: Python :: 3",