Skip to content

Commit

Permalink
Merge pull request IBM#19 from mikulatomas/predicate_exception
Browse files Browse the repository at this point in the history
Add new exceptions and fix old ones
  • Loading branch information
NaweedAghmad authored Mar 31, 2022
2 parents 3dd77fb + 92d410b commit dfcef59
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 28 deletions.
72 changes: 50 additions & 22 deletions lnn/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
# SPDX-License-Identifier: Apache-2.0
##

from typing import Set, Tuple, Dict
from typing import Set, Tuple, Dict, Union

import torch
import lnn

from .constants import Fact, World, Direction


class AssertWorld:
"""AssertWorld(bounds)
"""AssertWorld(world: World)
Raised when world not given as World object
"""
Expand All @@ -27,7 +28,7 @@ def __init__(self, world: World):


class AssertBoundsBroadcasting:
"""AssertBoundsBroadcasting(bounds)
"""AssertBoundsBroadcasting(bounds: Set)
Raised when FOL bounds given as set of more than 1 item
"""
Expand All @@ -39,9 +40,11 @@ def __init__(self, bounds: Set):


class AssertBoundsType:
"""Raised when bounds given in the incorrect type"""
"""AssertBoundsType(bounds: Union[Fact, tuple])
def __init__(self, bounds):
Raised when bounds given in the incorrect type"""

def __init__(self, bounds: Union[Fact, tuple]):
options = [Fact, tuple]
if type(bounds) not in options:
raise TypeError(
Expand All @@ -51,12 +54,12 @@ def __init__(self, bounds):


class AssertBoundsLen:
"""AssertBoundsLen(bounds)
"""AssertBoundsLen(bounds: tuple)
Raised when tuple of bounds given in the incorrect length
"""

def __init__(self, bounds):
def __init__(self, bounds: tuple):
if isinstance(bounds, tuple):
if len(bounds) != 2:
raise IndexError(
Expand All @@ -66,7 +69,7 @@ def __init__(self, bounds):


class AssertBoundsInputs:
"""AssertBounds(bounds)
"""AssertBounds(bounds: Tuple[torch.Tensor, ...])
Raised when incorrect bounds given
"""
Expand All @@ -93,6 +96,11 @@ def __init__(self, bounds):


class AssertPropositionalInheritance:
"""AssertPropositionalInheritance(node)
Raise when node does not have propositional attribute
"""

def __init__(self, node):
if not hasattr(node, "propositional"):
raise Exception(
Expand All @@ -104,18 +112,23 @@ def __init__(self, node):


class AssertFormulaInModel:
def __init__(self, model, formula):
"""AssertFormulaInModel(model: lnn.Model, formula: lnn.symbolic.logic._Formula)
Raised when formula is not in the model
"""

def __init__(self, model: "lnn.Model", formula: "lnn.symbolic.logic._Formula"):
if formula not in model:
raise Exception(f"{formula} is not a stored formula, can't set facts")


class AssertGroundingKeyType:
"""AssertGroundingKeyType(facts)
"""AssertGroundingKeyType(facts: Dict)
Raised when fact keys are not valid groundings
"""

def __init__(self, facts):
def __init__(self, facts: Dict):
if isinstance(facts, dict):
if all([type(f) not in [tuple, Fact] for f in facts.keys()]):
raise TypeError(
Expand All @@ -124,7 +137,7 @@ def __init__(self, facts):


class AssertFOLFacts:
"""AssertGroundedBounds(bounds)
"""AssertFOLFacts(facts: Dict)
Raised when FOL bounds expected as a dict of {groundings: facts}
Expand Down Expand Up @@ -174,20 +187,20 @@ def __init__(self, direction: Direction):


class AssertBias:
"""AssertDirectionType(direction: Direction)
"""AssertBias(bias: float)
Raised when direction not a clarified str
Raised when bias is not float type
"""

def __init__(self, bias):
def __init__(self, bias: float):
if not isinstance(bias, float):
raise TypeError(f"bias expected as a float, received {type(bias)}: {bias}")


class AssertWeights:
"""AssertDirectionType(direction: Direction)
"""AssertWeights(weights: Tuple, arity: int)
Raised when direction not a clarified str
Raised when weights are wrong type or length does not match arity
"""

def __init__(self, weights: Tuple, arity: int):
Expand All @@ -203,26 +216,41 @@ def __init__(self, weights: Tuple, arity: int):


class AssertAlphaNodeValue:
"""AssertAlphaInitValue(alpha: torch.Tensor)
"""AssertAlphaNodeValue(alpha: torch.Tensor)
Raised when alpha not in range
"""

def __init__(self, alpha):
def __init__(self, alpha: torch.Tensor):
if not (0.5 < alpha <= 1):
raise ValueError(f"alpha expected between (.5, 1], received {alpha}")


class AssertAlphaNeuronArityValue:
"""AssertAlphaInitValue(arity: int)
"""AssertAlphaNeuronArityValue(alpha: torch.Tensor, arity: int)
Raised when alpha not in range
Raised when alpha is not larger than constraint
"""

def __init__(self, alpha, arity):
def __init__(self, alpha: torch.Tensor, arity: int):
constraint = arity / (arity + 1)
if not (alpha >= constraint):
raise ValueError(
f"alpha expected greater than n/(n+1) ({constraint:<.3e}) "
f"for n={arity}, received {alpha:<3e}"
)


class AssertCalledPredicate:
"""AssertCalledPredicate(formula: Tuple[lnn.symbolic.logic._Formula])
Raised when predicate in any subformula is not properly called
"""

def __init__(self, formula: Tuple["lnn.symbolic.logic._Formula", ...]):
if formula:
for subformula in formula:
if isinstance(subformula, lnn.Predicate):
raise ValueError(
f"predicate {subformula} inside formula must be called"
)
7 changes: 1 addition & 6 deletions lnn/symbolic/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,7 @@ def __init__(
**kwds,
):
# check if all subformula has called predicates
if formula:
for subformula in formula:
if isinstance(subformula, Predicate):
raise ValueError(
f"predicate {subformula} inside formula must be called"
)
_exceptions.AssertCalledPredicate(formula)

# formula naming
self.name = (
Expand Down

0 comments on commit dfcef59

Please sign in to comment.