+##
+# Copyright 2021 IBM Corp. All Rights Reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+##
+
+# flake8: noqa: E501
+
+import logging
+import importlib
+import itertools
+from typing import Optional, Union, Tuple, Iterator, Set, List, Dict, TypeVar
+
+from . import _gm, _trace
+from ._bindings import get_bindings
+from .. import _utils, _exceptions, utils
+from ..constants import Fact, World, Direction, Join, NeuralActivation, Bound
+
+import copy
+import torch
+import numpy as np
+
+_utils.logger_setup()
+
+
+##
+# Internal Classes
+#
+# All classes below are protected classes (denoted with a prefixed underscore).
+# By convention they are kept internal rather than exposed to the end-user as
+# part of the API.
+# The formulas are ordered according to the inheritance structure in
+# https://ibm.github.io/LNN/lnn/LNN.html
+##
+
+
+class _LeafFormula(Formula):
+ r"""Specifies activation functionality as nodes instead of neurons.
+
+ Assumes that all leaf formulae are propositions or predicates, therefore
+ uses the _NodeActivation accordingly
+
+ """
+
+ def __init__(self, *args, **kwds):
+ super().__init__(*args, **kwds)
+ kwds.setdefault("propositional", self.propositional)
+ self.neuron = _NodeActivation()(**kwds.get("activation", {}), **kwds)
+
+
+class _ConnectiveFormula(Formula):
+ def __init__(self, *formula: Formula, **kwds):
+ super().__init__(*formula, **kwds)
+
+
+class _ConnectiveNeuron(_ConnectiveFormula):
+ def __init__(self, *formula, **kwds):
+ super().__init__(*formula, **kwds)
+ kwds.setdefault("arity", self.arity)
+ kwds.setdefault("propositional", self.propositional)
+ if kwds.get("activation", {}).get("negative_weights"):
+ self.negation_absorption()
+ self.neuron = _NeuralActivation(kwds.get("activation", {}).get("type"))(
+ **kwds.get("activation", {}), **kwds
+ )
+ self.func = self.neuron.activation(
+ self.__class__.__name__, direction=Direction.UPWARD
+ )
+ self.func_inv = self.neuron.activation(
+ self.__class__.__name__, direction=Direction.DOWNWARD
+ )
+
+ def upward(
+ self, groundings: Set[Union[str, Tuple[str, ...]]] = None, **kwds
+ ) -> float:
+ r"""Upward inference from the operands to the operator.
+
+ Parameters
+ ----------
+ groundings : str or tuple of str, optional
+ restrict upward inference to a specific grounding or row in the truth table
+ lifted : bool, optional
+ flag that determines if lifting should be done on this node.
+
+ Returns
+ -------
+ tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+ """
+ # Create (potentially) new groundings from functions
+ if not self.propositional:
+ self._ground_functions()
+
+ if kwds.get("lifted"):
+ result = self.neuron.aggregate_world(
+ tuple(
+ self.func(
+ torch.stack(
+ [torch.tensor(op.world) for op in self.operands], dim=-1
+ )[None]
+ ).tolist()[0]
+ )
+ )
+ if result:
+ logging.info(
+ "↑ WORLD FREE-VARIABLE UPDATED "
+ f"TIGHTENED:{result} "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ else:
+ upward_bounds = _gm.upward_bounds(self, self.operands, groundings)
+ if upward_bounds is None: # contradiction arresting
+ return 0.0
+ input_bounds, groundings = upward_bounds
+ grounding_rows = (
+ None
+ if self.propositional
+ else (
+ self.grounding_table.values()
+ if groundings is None
+ else [self.grounding_table.get(g) for g in groundings]
+ )
+ )
+ result = self.neuron.aggregate_bounds(
+ grounding_rows, self.func(input_bounds)
+ )
+ if result:
+ logging.info(
+ "↑ BOUNDS UPDATED "
+ f"TIGHTENED:{result} "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ if self.is_contradiction():
+ logging.info(
+ "↑ CONTRADICTION "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ return result
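+
+    # For illustration, a minimal upward pass through a connective neuron
+    # (a sketch, assuming the public classes defined later in this module
+    # and that facts were already added to the operands):
+    #
+    #     A, B = Propositions("A", "B")
+    #     A.add_data(Fact.TRUE)
+    #     B.add_data(Fact.TRUE)
+    #     And(A, B).upward()  # aggregates operand bounds into the conjunction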
+
+ def downward(
+ self,
+ index: int = None,
+ groundings: Set[Union[str, Tuple[str, ...]]] = None,
+ **kwds,
+ ) -> float:
+ r"""Downward inference from the operator to the operands.
+
+ Parameters
+ ----------
+ index : int, optional
+ restricts downward inference to an operand at the specified index. If unspecified, all operands are updated.
+ groundings : str or tuple of str, optional
+            restrict downward inference to a specific grounding or row in the truth table
+ lifted : bool, optional
+ flag that determines if lifting should be done on this node.
+
+ Returns
+ -------
+ tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+ """
+ # Create (potentially) new groundings from functions
+ if not self.propositional:
+ self._ground_functions()
+
+ if kwds.get("lifted"):
+ result = 0.0
+ new_worlds = self.func_inv(
+ torch.tensor(self.world)[None],
+ torch.stack([torch.tensor(op.world) for op in self.operands], dim=-1)[
+ None
+ ],
+ )
+ for op_idx, op in enumerate(self.operands):
+ op_aggregate = op.neuron.aggregate_world(
+ tuple(new_worlds[..., op_idx].tolist()[0])
+ )
+ result += op_aggregate
+ if op_aggregate:
+ logging.info(
+ "↓ WORLD FREE-VARIABLE UPDATED "
+ f"TIGHTENED:{op_aggregate} "
+ f"FOR:'{op.name}' "
+ f"FROM:'{self.name}' "
+ f"FORMULA:{op.formula_number} "
+ f"PARENT:{self.formula_number} "
+ )
+ else:
+ downward_bounds = _gm.downward_bounds(self, self.operands, groundings)
+ if downward_bounds is None: # contradiction arresting
+ return 0.0
+ out_bounds, input_bounds, groundings = downward_bounds
+ new_bounds = self.func_inv(out_bounds, input_bounds)
+ op_indices = (
+ enumerate(self.operands)
+ if index is None
+ else ([(index, self.operands[index])])
+ )
+ result = 0.0
+ for op_index, op in op_indices:
+ if op.propositional:
+ op_grounding_rows = None
+ else:
+ if groundings is None:
+ op_grounding_rows = op.grounding_table.values()
+ else:
+ op_grounding_rows = [None] * len(groundings)
+ for g_i, g in enumerate(groundings):
+ op_g = [
+ str(g.partial_grounding[slot])
+ for slot in self.operand_map[op_index]
+ ]
+ op_g = _Grounding(tuple(op_g) if len(op_g) > 1 else op_g[0])
+ op_grounding_rows[g_i] = op.grounding_table.get(op_g)
+ op_aggregate = op.neuron.aggregate_bounds(
+ op_grounding_rows, new_bounds[..., op_index]
+ )
+ if op_aggregate:
+ logging.info(
+ "↓ BOUNDS UPDATED "
+ f"TIGHTENED:{op_aggregate} "
+ f"FOR:'{op.name}' "
+ f"FROM:'{self.name}' "
+ f"FORMULA:{op.formula_number} "
+ f"PARENT:{self.formula_number} "
+ )
+ if op.is_contradiction():
+ logging.info(
+ "↓ CONTRADICTION "
+ f"FOR:'{op.name}' "
+ f"FROM:'{self.name}' "
+ f"FORMULA:{op.formula_number} "
+ f"PARENT:{self.formula_number} "
+ )
+ result = result + op_aggregate
+ return result
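+
+    # A matching downward sketch (hypothetical usage): a true implication with
+    # a true antecedent tightens the consequent toward TRUE (modus ponens):
+    #
+    #     A, B = Propositions("A", "B")
+    #     AB = Implies(A, B)
+    #     AB.add_data(Fact.TRUE)
+    #     A.add_data(Fact.TRUE)
+    #     AB.downward()  # B's lower bound is tightened toward 1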
+
+ def _logical_loss(
+ self, coeff: float = None, slacks: Union[bool, float] = None
+ ) -> torch.Tensor:
+ r"""Logical loss to create a loss on logical constraint violation.
+
+ Assumes a soft logic computation and calculates the loss on constraints
+        as defined in [equations 86-89](https://arxiv.org/pdf/2006.13155.pdf).
+        When slacks are given, the constraints are allowed to be violated;
+        however, this affects the neuron's interpretability and should only be
+        used if the model is not strictly required to obey a classical
+        definition of logic.
+
+ """
+ a = self.neuron.alpha
+ b = self.neuron.bias
+ w = self.neuron.weights
+ T, F = a, 1 - a
+ coeff = 1 if coeff is None else coeff
+ if isinstance(self, And):
+ TRUE = b - (w * (1 - T)).sum()
+ FALSE = b - (w * (1 - F))
+ true_hinge = torch.where(TRUE < T, T - TRUE, TRUE * 0)
+ false_hinge = torch.where(FALSE > F, FALSE - F, FALSE * 0)
+ if slacks:
+ if slacks is True:
+ slacks_false = false_hinge * (false_hinge > 0)
+ slacks_true = true_hinge * (true_hinge > 0)
+ false_hinge -= slacks_false
+ true_hinge -= slacks_true
+ self.neuron.slacks = (
+ slacks_true.detach().clone(),
+ slacks_false.detach().clone(),
+ )
+ else:
+ false_hinge -= slacks
+ self.neuron.feasibility = (
+ true_hinge.detach().clone(),
+ false_hinge.detach().clone(),
+ )
+
+ elif isinstance(self, Or):
+ TRUE = 1 - b + (w * T)
+ FALSE = 1 - b + (w * F).sum()
+ true_hinge = torch.where(TRUE < T, T - TRUE, TRUE * 0).sum()
+ false_hinge = torch.where(FALSE > F, FALSE - F, FALSE * 0)
+ elif isinstance(self, Implies):
+ TRUE = 1 - b + (w * T) # T = 1-F for x and T for y
+ FALSE = 1 - b + (w[0] * (1 - T)) + (w[1] * F)
+ true_hinge = torch.where(TRUE < T, T - TRUE, TRUE * 0).sum()
+ false_hinge = torch.where(FALSE > F, FALSE - F, FALSE * 0)
+ result = (true_hinge.square() + false_hinge.square()).sum()
+ return coeff * result
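+
+    # The hinges above penalise violations of the classical constraints; for a
+    # conjunction with threshold-of-truth α, bias β and weights w_i they read:
+    #
+    #     β - Σ_i w_i·(1 - α) ≥ α          (all-true inputs must evaluate true)
+    #     β - w_i·α ≤ 1 - α   for each i   (a single false input must evaluate false)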
+
+    def neural_equivalence(self, other):
+        return (
+            isinstance(self.neuron, type(other.neuron))
+            and self.neuron.bias == other.neuron.bias
+            and len(self.neuron.weights) == len(other.neuron.weights)
+            and all(
+                self.neuron.weights[idx] == other.neuron.weights[idx]
+                for idx in range(len(self.neuron.weights))
+            )
+        )
+
+
+class _NAryNeuron(_ConnectiveNeuron):
+ r"""N-ary connective neuron."""
+
+ def __init__(self, *formula, **kwds):
+ super().__init__(*formula, arity=len(formula), **kwds)
+
+
+class _BinaryNeuron(_ConnectiveNeuron):
+ r"""Restrict neurons to 2 inputs."""
+
+ def __init__(self, *formula, **kwds):
+ if len(formula) != 2:
+ raise Exception(
+ "Binary neurons expect 2 formulae as inputs, received "
+ f"{len(formula)}"
+ )
+ super().__init__(*formula, arity=2, **kwds)
+
+
+class _UnaryOperator(_ConnectiveFormula):
+ r"""Restrict operators to 1 input."""
+
+ def __init__(self, *formula: Formula, **kwds):
+ if len(formula) != 1:
+ raise Exception(
+ "Unary operator expect 1 formula as input, received " f"{len(formula)}"
+ )
+ super().__init__(*formula, arity=1, **kwds)
+
+
+class _Quantifier(_UnaryOperator):
+ r"""Symbolic container for quantifiers.
+
+ Parameters
+ ------------
+ kwds : dict
+ fully_grounded : bool
+ specifies if a full upward inference can be done on a
+ quantifier due to all the groundings being present inside it.
+ This applies to the lower bound of a `ForAll` and upper bound
+ of an `Exists`
+
+ Attributes
+ ----------
+ fully_grounded : bool
+ unique_var_slots : tuple of int
+ returns the slot index of each unique variable
+
+ """
+
+ def __init__(self, quantified_variable: "Variable", formula: Formula, **kwds):
+ self.quantified_variable = quantified_variable
+ super().__init__(formula, **kwds)
+ self.fully_grounded = kwds.get("fully_grounded", False)
+ self._grounding_set = set()
+ self._set_activation(**kwds)
+
+ @property
+ def expanded_unique_vars(self):
+ result = list(self.unique_vars)
+        result.extend(self.variables)
+ return tuple(result)
+
+ @staticmethod
+ def _unique_variables_overlap(
+ source: Tuple["Variable", ...], destination: Tuple["Variable", ...]
+ ) -> Tuple["Variable", ...]:
+ """combines all predicate variables into a unique tuple
+ the tuple is sorted by the order of appearance of variables in
+ the operands
+ """
+ result = list()
+ for dst_var in destination:
+ if dst_var not in source:
+ result.append(dst_var)
+ return tuple(result)
+
+ def upward(self, **kwds) -> float:
+ r"""Upward inference from the operands to the operator.
+
+ Parameters
+ ----------
+ lifted : bool, optional
+ flag that determines if lifting should be done on this node.
+
+ Returns
+ -------
+ tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+ """
+
+ # Create (potentially) new groundings from functions
+ if not self.propositional:
+ self._ground_functions()
+
+ if kwds.get("lifted"):
+ result = self.neuron.aggregate_world(self.operands[0].world)
+ if result:
+ if self.propositional:
+ self.neuron.reset_world(self.world)
+ logging.info(
+ "↑ WORLD FREE-VARIABLE UPDATED "
+ f"TIGHTENED:{result} "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ else:
+ n_groundings = len(self._grounding_set)
+ input_bounds = self._upward_bounds(self.operands[0])
+ if input_bounds is None:
+ return 0.0
+ if len(self._grounding_set) > n_groundings:
+ self._set_activation(world=self.world)
+ result = self.neuron.aggregate_bounds(
+ None,
+ self.func(input_bounds.permute([1, 0])),
+ bound=(
+ (Bound.UPPER if isinstance(self, ForAll) else Bound.LOWER)
+ if not self.fully_grounded
+ else None
+ ),
+ )
+ if result:
+ logging.info(
+ "↑ BOUNDS UPDATED "
+ f"TIGHTENED:{result} "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ if self.is_contradiction():
+ logging.info(
+ "↑ CONTRADICTION "
+ f"FOR:'{self.name}' "
+ f"FORMULA:{self.formula_number} "
+ )
+ return result
+
+ def _set_activation(self, **kwds):
+ """Updates the neural activation according to grounding dimension size
+
+ The computation of a quantifier is implemented via one of the weighed
+ neurons, And/Or for ForAll/Exists.
+ At present, weighted quantifiers have not been well studied and is
+ therefore turned off
+ However the dimension of computation is different, computing over the
+ groundings of the input formula instead of multiple formulae, since
+ there can only be one formula to quantify over.
+ The activation is therefore required to grow according to number of
+ groundings present in the formula, which can grow as groundings
+ propagate via inference.
+
+ """
+ operator = (
+ "And"
+ if isinstance(self, ForAll)
+ else ("Or" if isinstance(self, Exists) else None)
+ )
+ kwds.setdefault("arity", len(self._grounding_set))
+ kwds.setdefault("propositional", self.propositional)
+ self.neuron = _NeuralActivation()(
+ activation={"weights_learning": False}, **kwds
+ )
+ self.func = self.neuron.activation(operator, direction=Direction.UPWARD)
+
+ @staticmethod
+ def _has_free_variables(
+ variables: Tuple["Variable", ...], operand: Formula
+ ) -> bool:
+ r"""Returns True if the quantifier contains free variables."""
+ return len(set(variables)) != len({*operand.unique_vars})
+
+ @property
+ def true_groundings(
+ self,
+ ) -> Set[Union[str, Tuple[str, ...]]]:
+ r"""Returns a set of groundings that are True."""
+ valid_groundings = [
+ g for g in self._grounding_set if self.operands[0].state(g) is Fact.TRUE
+ ]
+ return self._groundings(valid_groundings) if valid_groundings else set()
+
+ def _upward_bounds(self, operand: Formula) -> Union[torch.Tensor, None]:
+ r"""Set Quantifier grounding table and return operand tensor."""
+ operand_grounding_set = set(operand.grounding_table)
+ if len(operand_grounding_set) == 0:
+ return
+
+        self._grounding_set = {
+            grounding
+            for grounding in operand_grounding_set
+            if _gm.is_grounding_in_bindings(self, 0, grounding)
+        }
+ return operand.get_data(*self._grounding_set) if self._grounding_set else None
+
+ def _groundings(self, groundings=None) -> Set[Union[str, Tuple[str, ...]]]:
+ """Internal usage to extract groundings as _Grounding object"""
+ return set(map(_Grounding.eval, groundings)) if groundings else self.groundings
+
+ @property
+ def groundings(self) -> Set[Union[str, Tuple[str, ...]]]:
+ r"""returns a set of groundings as str or tuple of str"""
+ return set(map(_Grounding.eval, self._grounding_set))
+
+ def add_data(self, facts: Union[Tuple[float, float], Fact, Set]):
+ super().add_data(facts)
+ self._set_activation(world=self.world)
+
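+# A small (hypothetical) sketch of the `fully_grounded` flag documented above:
+# when every grounding is known to be present inside the quantifier, both
+# bounds may be aggregated instead of only one:
+#
+#     x = Variable("x")
+#     P = Predicate("P")
+#     All = ForAll(x, P(x), fully_grounded=True)
+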
+
+class _NAryOperator(_ConnectiveFormula):
+ r"""N-ary connective operator"""
+
+ def __init__(self, *formula, **kwds):
+ super().__init__(*formula, arity=len(formula), **kwds)
+
+
+class _Grounding(_utils.MultiInstance, _utils.UniqueNameAssumption):
+ r"""Propositionalises constants for first-order logic
+
+ Returns a container for a string or a tuple of strings.
+    Follows the unique name assumption so that the same constant(s) return the
+    same object.
+    Decomposes multiple constants (from the tuple) by storing each str as a
+    separate grounding object but returns only the compound container.
+    This decomposition is used in grounding management to ensure that all
+    partial strings also follow the unique name assumption by returning the
+    same container.
+
+ Parameters
+ ------------
+ constants : str or tuple of str
+
+ Examples
+ --------
+ ```python
+ _Grounding('person1')
+ _Grounding(('person1', 'date1'))
+ ```
+
+ Attributes
+ ----------
+ name : str
+ conversion of 'constants' param to str form
+ grounding_arity : int
+ length of the 'constants' param
+    partial_grounding : tuple of _Grounding
+ tuple of groundings for decomposition when constants given as tuple
+
+ """
+
+ def __init__(self, constants: Union[str, Tuple[str, ...]]):
+ super().__init__(constants)
+ self.name = str(constants)
+ if isinstance(constants, tuple):
+ self.grounding_arity = len(constants)
+ self.partial_grounding = tuple(
+ map(self._partial_grounding_from_str, constants)
+ )
+ else:
+ self.grounding_arity = 1
+ self.partial_grounding = (self,)
+
+ @classmethod
+ def _partial_grounding_from_str(cls, constant: str) -> "_Grounding":
+ r"""Returns partial Grounding given grounding str"""
+ return _Grounding.instances[constant]
+
+ @classmethod
+ def ground_by_groundings(cls, *grounding: "_Grounding"):
+ r"""Reduce a tuple of groundings to a single grounding"""
+ return (
+ grounding[0]
+ if len(grounding) == 1
+            else cls(tuple(str(g) for g in grounding))
+ )
+
+ def __len__(self) -> int:
+ r"""Returns the length of the grounding arity"""
+ return self.grounding_arity
+
+ def __str__(self) -> str:
+ r"""Returns the name of the grounding"""
+ return self.name
+
+ @staticmethod
+ def eval(grounding: "_Grounding") -> Union[str, Tuple[str, ...]]:
+ r"""Returns the original constant(s) in str or tuple of str form"""
+ return eval(grounding.name) if grounding.grounding_arity > 1 else grounding.name
+
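+# A round-trip sketch of the container (derived from `__init__` and `eval`):
+#
+#     g = _Grounding(("person1", "date1"))
+#     g.name                # "('person1', 'date1')"
+#     _Grounding.eval(g)    # ("person1", "date1")
+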
+
+class _NodeActivation:
+ def __call__(self, **kwds):
+ return getattr(
+ importlib.import_module("lnn.neural.activations.node"),
+ "_NodeActivation",
+ )(**kwds)
+
+
+class _NeuralActivation:
+ r"""Switch class, to choose a method from the correct activation class"""
+
+ def __init__(self, type=None):
+ self.neuron_type = type if type else NeuralActivation.LukasiewiczTransparent
+ _exceptions.AssertNeuronActivationType(self.neuron_type)
+ self.module = importlib.import_module(
+ f"lnn.neural.methods.{self.neuron_type.name.lower()}"
+ )
+
+ def __call__(self, **kwds):
+ return getattr(self.module, self.neuron_type.name)(**kwds)
+
+
+##
+# Public Classes
+#
+# All classes below should be exposed via the public API
+# The formulas are ordered alphabetically according to the API docs in
+# https://ibm.github.io/LNN/lnn/LNN.html
+##
+
+
+class Predicate(_LeafFormula):
+    r"""Creates a container for a predicate
+
+    Stores a table of truths, with columns specified by the arity and rows
+    indexed by the grounding
+
+    Parameters
+    ----------
+    name : str
+        name of the predicate
+    arity : int, optional
+        If unspecified, assumes a unary predicate
+
+    Examples
+    --------
+    ```python
+    P1 = Predicate('P1')
+    P2 = Predicate('P2', arity=2)
+    ```
+
+    """
+
+    def __init__(self, name: str, arity: int = 1, **kwds):
+        if arity is None:
+            raise Exception(f"arity expected as int > 0, received {arity}")
+        super().__init__(name=name, arity=arity, propositional=False, **kwds)
+        self._update_variables(tuple(Variable(f"?{i}") for i in range(self.arity)))
+
+    def add_data(self, facts: Union[dict, set]):
+        r"""Populate predicate with facts
+
+        Facts required in dict or set form
+            - dict for grounding-based facts,
+              with keys for groundings and values as facts
+            - set for broadcasting facts across all groundings,
+              requires a set of 1 item
+        tuple facts required in bounds form `(Lower, Upper)`
+
+        """
+        super().add_data(facts)
+
+    def __call__(self, *args, **kwds):
+        r"""A called first-order logic predicate
+
+        This correctly instantiates a predicate with variables, which is required
+        when using the predicate in a compound formula. Calling the predicate
+        allows the LNN to construct the inheritance tree from subformulae.
+
+        Examples
+        --------
+        ```python
+        P, Q = Predicates('P', 'Q')
+        x, y = Variables('x', 'y')
+        And(P(x), Q(y))  # calling both predicates
+        ```
+        Here the conjunction inherits its variables from all subformulae,
+        treating them as an ordered unique collection (list).
+        """
+        return super().__call__(*args, **kwds)
+
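+# A short usage sketch of `Predicate.add_data` with the fact forms documented
+# above (the groundings are illustrative):
+#
+#     P2 = Predicate("P2", arity=2)
+#     P2.add_data({("person1", "date1"): Fact.TRUE})   # grounding-based fact
+#     P2.add_data({("person2", "date2"): (0.6, 1.0)})  # bounds form (Lower, Upper)
+#     P2.add_data({Fact.UNKNOWN})                      # 1-item set broadcasts to all groundings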
+
+def Predicates(*predicates: str, **kwds):
+    r"""Instantiates multiple predicates.
+
+    Examples
+    --------
+    ```python
+    P1, P2 = Predicates("P1", "P2", arity=2)
+    ```
+
+    """
+    return utils.return1([Predicate(p, **kwds) for p in predicates])
+
+
+class Proposition(_LeafFormula):
+    r"""Creates propositional containers
+
+    Stores and retrieves single truth bounds instead of tables as in the FOL case
+
+    Parameters
+    ----------
+    name : str
+        name of the proposition
+
+    Examples
+    --------
+    ```python
+    P = Proposition('Person')
+    ```
+
+    """
+
+    def __init__(self, name: str, **kwds):
+        super().__init__(name=name, arity=1, propositional=True, **kwds)
+
+    def add_data(self, fact: Fact):
+        """Populate proposition with facts
+
+        Facts required in bool, tuple or None
+            None fact assumes `Unknown`
+            tuple fact required in bounds form `(Lower, Upper)`
+
+        """
+        super().add_data(fact)
+
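+# A short usage sketch of `Proposition.add_data` with the fact forms documented
+# above:
+#
+#     P = Proposition("Person")
+#     P.add_data(True)        # classical fact
+#     P.add_data((0.7, 1.0))  # bounds form (Lower, Upper)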
+
+def Propositions(*propositions: str, **kwds):
+    r"""Instantiates multiple propositions.
+
+    Examples
+    --------
+    ```python
+    P1, P2 = Propositions("P1", "P2")
+    ```
+
+    """
+    return utils.return1([Proposition(p, **kwds) for p in propositions])
+
+
+class Variable:
+    r"""Free variables to quantify first-order logic formulae
+
+    Parameters
+    ----------
+    name : str
+        name of the free variable
+    type : str, optional
+        constant of the type associated with the free variable
+
+    Examples
+    --------
+    ```python
+    x = Variable('x', 'person')
+    ```
+
+    """
+
+    def __init__(self, name: str, type: Optional[str] = None):
+        self.name = name
+        self.type = type
+
+    def __str__(self) -> str:
+        r"""Returns the name of the free variable"""
+        return self.name
+
+
+def Variables(*variables: str, **kwds) -> Union[Variable, Tuple[Variable, ...]]:
+    """Instantiates multiple variables.
+
+    Examples
+    --------
+    ```python
+    x, y = Variables("x", "y")
+    ```
+
+    """
+    return utils.return1([Variable(v, **kwds) for v in variables])
+
+
+class And(_NAryNeuron):
+    r"""Symbolic n-ary [conjunction](https://en.wikipedia.org/wiki/Logical_conjunction).
+
+    Returns a logical conjunction where inputs can be [propositions](LNN.html#lnn.Proposition), `called` first-order logic [predicates](LNN.html#lnn.Predicate) or any other [connective formulae](LNN.html#symbolic-structure).
+    Propositional inputs yield a propositional node, whereas if any input is a predicate it will cause the connective to increase its dimension to also be a FOL node (i.e. stores a table of facts).
+
+    Parameters
+    ----------
+    ``*formula`` : Formula
+        A variable length argument list that accepts any number of input formulae objects as arguments.
+    name : str, optional
+        A custom name for the node to be used for identification and custom printing. If unspecified, defaults to the structure of the node.
+    activation : dict, optional
+        Parameters given as a dictionary of configuration options, see the [neural configuration](../usage.html#neural-configuration) for more details.
+
+    Examples
+    --------
+    ```python
+    # Propositional
+    A, B, C = Propositions('A', 'B', 'C')
+    And(A, B, C)
+    ```
+    ```python
+    # First-order logic
+    x, y = Variables('x', 'y')
+    A, C = Predicates('A', 'C')
+    B = Predicate('B', arity=2)
+    And(A(x), B(x, y), C(y))
+    ```
+
+    """
+
+    def __init__(self, *formula: Formula, **kwds):
+        kwds.setdefault("activation", {})
+        self.operator_str = self.get_operator_str(
+            kwds["activation"].get("type", None)
+        )
+        super().__init__(*formula, **kwds)
+
+    @staticmethod
+    def get_operator_str(type: NeuralActivation) -> str:
+        return f"{type.name[0]}∧" if type else "∧"
+
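+# The `activation` dictionary accepts the configuration keys used elsewhere in
+# this module, e.g. `type`, `negative_weights`, `bias_learning` and
+# `weights_learning`; a sketch:
+#
+#     A, B = Propositions("A", "B")
+#     And(A, B, activation={"negative_weights": True})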
+
+class Or(_NAryNeuron):
+    r"""Symbolic n-ary [disjunction](https://en.wikipedia.org/wiki/Logical_disjunction).
+
+    Returns a logical disjunction where inputs can be [propositions](LNN.html#lnn.Proposition), `called` first-order logic [predicates](LNN.html#lnn.Predicate) or any other [connective formulae](LNN.html#symbolic-structure).
+    Propositional inputs yield a propositional node, whereas if any input is a predicate it will cause the connective to increase its dimension to also be a FOL node (i.e. stores a table of facts).
+
+    Parameters
+    ----------
+    ``*formula`` : Formula
+        A variable length argument list that accepts any number of input formulae objects as arguments.
+    name : str, optional
+        A custom name for the node to be used for identification and custom printing. If unspecified, defaults to the structure of the node.
+    activation : dict, optional
+        Parameters given as a dictionary of configuration options, see the [neural configuration](../usage.html#neural-configuration) for more details.
+
+    Examples
+    --------
+    ```python
+    # Propositional
+    A, B, C = Propositions('A', 'B', 'C')
+    Or(A, B, C)
+    ```
+    ```python
+    # First-order logic
+    x, y = Variables('x', 'y')
+    A, C = Predicates('A', 'C')
+    B = Predicate('B', arity=2)
+    Or(A(x), B(x, y), C(y))
+    ```
+
+    """
+
+    def __init__(self, *formula, **kwds):
+        kwds.setdefault("activation", {})
+        self.operator_str = self.get_operator_str(
+            kwds["activation"].get("type", None)
+        )
+        super().__init__(*formula, **kwds)
+
+    @staticmethod
+    def get_operator_str(type: NeuralActivation) -> str:
+        return f"{type.name[0]}∨" if type else "∨"
+
+
+class Implies(_BinaryNeuron):
+    r"""Symbolic binary [implication](https://en.wikipedia.org/wiki/Logical_implication).
+
+    Returns a logical implication node where inputs can be [propositions](LNN.html#lnn.Proposition), `called` first-order logic [predicates](LNN.html#lnn.Predicate) or any other [connective formulae](LNN.html#symbolic-structure).
+    Propositional inputs yield a propositional node, whereas if any input is a predicate it will cause the connective to increase its dimension to also be a FOL node (i.e. stores a table of facts).
+
+    Parameters
+    ----------
+    lhs : Formula
+        The left-hand side formula of the binary inputs to the connective.
+    rhs : Formula
+        The right-hand side formula of the binary inputs to the connective.
+    name : str, optional
+        A custom name for the node to be used for identification and custom printing. If unspecified, defaults to the structure of the node.
+    activation : dict, optional
+        Parameters given as a dictionary of configuration options, see the [neural configuration](../usage.html#neural-configuration) for more details.
+
+    Examples
+    --------
+    ```python
+    # Propositional
+    A, B = Propositions('A', 'B')
+    Implies(A, B)
+    ```
+    ```python
+    # First-order logic
+    x, y = Variables('x', 'y')
+    A = Predicate('A')
+    B = Predicate('B', arity=2)
+    Implies(A(x), B(x, y))
+    ```
+
+    """
+
+    def __init__(self, lhs: Formula, rhs: Formula, **kwds):
+        kwds.setdefault("activation", {})
+        kwds["activation"].setdefault("bias_learning", True)
+        self.operator_str = self.get_operator_str(
+            kwds["activation"].get("type", None)
+        )
+        super().__init__(lhs, rhs, **kwds)
+
+    @staticmethod
+    def get_operator_str(type: NeuralActivation) -> str:
+        # if type is NeuralActivation.Product:
+        #     return "|"
+        return f"{type.name[0]}→" if type else "→"
+
+
+class Not(_UnaryOperator):
+    r"""Symbolic Negation
+
+    Returns a logical negation where inputs can be propositional,
+    first-order logic predicates or other connectives.
+
+    Parameters
+    ----------
+    formula : Formula
+        accepts a unary input Formula
+
+    Examples
+    --------
+    ```python
+    # Propositional
+    A = Proposition('A')
+    Not(A)
+    ```
+    ```python
+    # First-order logic
+    x, y = Variables('x', 'y')
+    A = Predicate('A', arity=2)
+    Not(A(x, y))
+    ```
+
+    """
+
+    def __init__(self, formula: Formula, **kwds):
+        self.operator_str = "¬"
+        super().__init__(formula, **kwds)
+        kwds.setdefault("propositional", self.propositional)
+        self.neuron = _NodeActivation()(**kwds.get("activation", {}), **kwds)
+
+    def upward(self, **kwds) -> float:
+        r"""Upward inference from the operands to the operator.
+
+        Parameters
+        ----------
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        # Create (potentially) new groundings from functions
+        if not self.propositional:
+            self._ground_functions()
+
+        if kwds.get("lifted"):
+            bounds = self.neuron.aggregate_world(
+                tuple(
+                    _utils.negate_bounds(torch.tensor(self.operands[0].world)).tolist()
+                )
+            )
+        else:
+            if self.propositional:
+                groundings = {None}
+            else:
+                groundings = tuple(self.operands[0]._groundings)
+                for g in groundings:
+                    if g not in self.grounding_table:
+                        self._add_groundings(g)
+            bounds = self.neuron.aggregate_bounds(
+                None, _utils.negate_bounds(self.operands[0].get_data(*groundings))
+            )
+            if self.is_contradiction():
+                logging.info(
+                    "↑ CONTRADICTION "
+                    f"FOR:'{self.name}' "
+                    f"FORMULA:{self.formula_number} "
+                )
+        return bounds
+
+    def downward(self, **kwds) -> torch.Tensor:
+        r"""Downward inference from the operator to the operands.
+
+        Parameters
+        ----------
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        # Create (potentially) new groundings from functions
+        if not self.propositional:
+            self._ground_functions()
+
+        if kwds.get("lifted"):
+            bounds = self.operands[0].neuron.aggregate_world(
+                tuple(_utils.negate_bounds(torch.tensor(self.world)).tolist())
+            )
+        else:
+            if self.propositional:
+                groundings = {None}
+            else:
+                groundings = tuple(self._groundings)
+                for g in groundings:
+                    if g not in self.operands[0]._groundings:
+                        self.operands[0]._add_groundings(g)
+            bounds = self.operands[0].neuron.aggregate_bounds(
+                None, _utils.negate_bounds(self.get_data(*groundings))
+            )
+            if self.operands[0].is_contradiction():
+                logging.info(
+                    "↓ CONTRADICTION "
+                    f"FOR:'{self.operands[0].name}' "
+                    f"FROM:'{self.name}' "
+                    f"FORMULA:{self.operands[0].formula_number} "
+                    f"PARENT:{self.formula_number} "
+                )
+        return bounds
+
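+# Negation flips bounds rather than learning parameters: a node with bounds
+# (L, U) yields (1 - U, 1 - L) via `_utils.negate_bounds`, e.g.
+#
+#     A = Proposition("A")
+#     A.add_data((0.2, 0.6))
+#     Not(A).upward()  # Not(A) holds bounds (0.4, 0.8)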
+
+class Equivalent(_BinaryNeuron):
+    r"""Symbolic Equivalence - a bidirectional binary implication or IFF (if and only if) node.
+
+    Returns a logical bidirectional equivalence node where inputs can be [propositions](LNN.html#lnn.Proposition),
+    `called` first-order logic [predicates](LNN.html#lnn.Predicate) or any other [connective formulae](LNN.html#symbolic-structure).
+    Propositional inputs yield a propositional node, whereas if any input is a predicate it will cause the connective to increase its dimension to also be a FOL node (i.e. stores a table of facts).
+
+    Parameters
+    ----------
+    lhs : Formula
+        The left-hand side formula of the binary inputs to the connective.
+    rhs : Formula
+        The right-hand side formula of the binary inputs to the connective.
+    name : str, optional
+        A custom name for the node to be used for identification and custom printing. If unspecified, defaults to the structure of the node.
+    activation : dict, optional
+        Parameters given as a dictionary of configuration options, see the [neural configuration](../usage.html#neural-configuration) for more details.
+
+    Examples
+    --------
+    ```python
+    # Propositional
+    A, B = Propositions('A', 'B')
+    Equivalent(A, B)
+    ```
+    ```python
+    # First-order logic
+    x, y = Variables('x', 'y')
+    A = Predicate('A')
+    B = Predicate('B', arity=2)
+    Equivalent(A(x), B(x, y))
+    ```
+
+    """
+
+    def __init__(self, lhs: Formula, rhs: Formula, **kwds):
+        self.operator_str = "∧"
+        self.Imp1, self.Imp2 = Implies(lhs, rhs, **kwds), Implies(rhs, lhs, **kwds)
+        super().__init__(self.Imp1, self.Imp2, **kwds)
+        self.func = self.neuron.activation("And", direction=Direction.UPWARD)
+        self.func_inv = self.neuron.activation("And", direction=Direction.DOWNWARD)
+
+    def upward(
+        self, groundings: Set[Union[str, Tuple[str, ...]]] = None, **kwds
+    ) -> float:
+        r"""Upward inference from the operands to the operator.
+
+        Parameters
+        ----------
+        groundings : str or tuple of str, optional
+            restrict upward inference to a specific grounding or row in the truth table
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        self.Imp1.upward(groundings, **kwds)
+        self.Imp2.upward(groundings, **kwds)
+        return super().upward(groundings, **kwds)
+
+    def downward(
+        self,
+        index: int = None,
+        groundings: Set[Union[str, Tuple[str, ...]]] = None,
+        **kwds,
+    ) -> float:
+        r"""Downward inference from the operator to the operands.
+
+        Parameters
+        ----------
+        index : int, optional
+            restricts downward inference to an operand at the specified index. If unspecified, all operands are updated.
+        groundings : str or tuple of str, optional
+            restrict downward inference to a specific grounding or row in the truth table
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        self.Imp1.downward(index, groundings, **kwds)
+        self.Imp2.downward(index, groundings, **kwds)
+        return super().downward(index, groundings, **kwds)
+
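+# Structurally (see `__init__` above), `Equivalent(A, B)` is the conjunction of
+# the two directed implications:
+#
+#     Equivalent(A, B)  ≈  And(Implies(A, B), Implies(B, A))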
+
+class Exists(_Quantifier):
+    r"""Symbolic existential quantifier.
+
+    When working with belief bounds - existential operators restrict upward inference to only work with the given formula's lower bound. Downward inference behaves as usual.
+
+    Parameters
+    ----------
+    ``*variables`` : Variable
+    formula : Formula
+        The FOL formula to quantify over, may be a connective formula or a Predicate.
+
+    Examples
+    --------
+    No free variables, quantifies over all of the variables in the formula.
+    ```python
+    Some_1 = Exists(birthdate(p, d))
+    Some_2 = Exists(p, d, birthdate(p, d))
+    ```
+
+    Free variables, quantifies over a subset of variables in the formula.
+    ```python
+    Some = Exists(p, birthdate(p, d))
+    ```
+
+    Warning
+    -------
+    Quantifiers with free variables are not yet implemented. It is required that we quantify over all the variables given in the formula, either by specifying all the variables or by not specifying any variables, which is equivalent to quantifying over all variables.
+
+    """
+
+    def __init__(self, *args, **kwds):
+        self.operator_str = "∃"
+        super().__init__(*args, **kwds)
+
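+# A sketch of querying satisfied groundings after inference (hypothetical data;
+# `true_groundings` is defined on `_Quantifier` above):
+#
+#     x = Variable("x")
+#     P = Predicate("P")
+#     P.add_data({"a": Fact.TRUE, "b": Fact.FALSE})
+#     Some = Exists(x, P(x))
+#     Some.upward()
+#     Some.true_groundings  # {"a"}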
+
+class ForAll(_Quantifier):
+    r"""Symbolic universal quantifier.
+
+    When working with belief bounds - universal operators restrict upward inference to only work with the given formula's upper bound. Downward inference behaves as usual.
+
+    Parameters
+    ----------
+    ``*variables`` : Variable
+    formula : Formula
+        The FOL formula to quantify over, may be a connective formula or a Predicate.
+
+    Examples
+    --------
+    No free variables, quantifies over all of the variables in the formula.
+    ```python
+    All_1 = ForAll(birthdate(p, d))
+    All_2 = ForAll(p, d, birthdate(p, d))
+    ```
+
+    Free variables, quantifies over a subset of variables in the formula.
+    ```python
+    All = ForAll(p, birthdate(p, d))
+    ```
+
+    Warning
+    -------
+    Quantifiers with free variables are not yet implemented. It is required that we quantify over all the variables given in the formula, either by specifying all the variables or by not specifying any variables, which is equivalent to quantifying over all variables.
+
+    """
+
+    def __init__(self, *args, **kwds):
+        self.operator_str = "∀"
+        kwds.setdefault("world", World.AXIOM)
+        super().__init__(*args, **kwds)
+
+    def downward(self, **kwds) -> Union[torch.Tensor, None]:
+        r"""Downward inference from the operator to the operands.
+
+        Parameters
+        ----------
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        # Create (potentially) new groundings from functions
+        if not self.propositional:
+            self._ground_functions()
+
+        if kwds.get("lifted"):
+            result = self.operands[0].neuron.aggregate_world(self.world)
+            if result:
+                logging.info(
+                    "↓ WORLD FREE-VARIABLE UPDATED "
+                    f"TIGHTENED:{result} "
+                    f"FOR:'{self.operands[0].name}' "
+                    f"FROM:'{self.name}' "
+                    f"FORMULA:{self.operands[0].formula_number} "
+                    f"PARENT:{self.formula_number} "
+                )
+        else:
+            if not self._grounding_set:
+                return
+            operand = self.operands[0]
+            current_bounds = self.get_data()
+            groundings = operand.grounding_table.keys()
+            result = operand.neuron.aggregate_bounds(
+                [operand.grounding_table.get(g) for g in groundings], current_bounds
+            )
+            if result:
+                logging.info(
+                    "↓ BOUNDS UPDATED "
+                    f"TIGHTENED:{result} "
+                    f"FOR:'{operand.name}' "
+                    f"FROM:'{self.name}' "
+                    f"FORMULA:{operand.formula_number} "
+                    f"PARENT:{self.formula_number} "
+                )
+            if operand.is_contradiction():
+                logging.info(
+                    "↓ CONTRADICTION "
+                    f"FOR:'{operand.name}' "
+                    f"FROM:'{self.name}' "
+                    f"FORMULA:{operand.formula_number} "
+                    f"PARENT:{self.formula_number} "
+                )
+        return result
+
+
+class Congruent(_NAryOperator):
+    r"""Symbolic Congruency
+
+    This is used to define nodes that are symbolically equivalent to one another
+    (despite the possibility of neural differences)
+
+    """
+
+    def __init__(self, *formulae: Formula, **kwds):
+        self.operator_str = "≅"
+        super().__init__(*formulae, **kwds)
+        kwds.setdefault("propositional", self.propositional)
+        self.neuron = _NodeActivation()(**kwds.get("activation", {}), **kwds)
+
+    def __contains__(self, item):
+        return item in self.congruent_nodes
+
+    def add_data(self, facts: Union[Fact, Tuple, Set, Dict]):
+        """Should not be called by the user"""
+        raise AttributeError(
+            "Should not be called directly by the user, instead use "
+            "`congruent_node.upward()` to evaluate the facts from the operands"
+        )
+
+    def upward(
+        self, groundings: Set[Union[str, Tuple[str, ...]]] = None, **kwds
+    ) -> float:
+        r"""Upward inference from the operands to the operator.
+
+        Parameters
+        ----------
+        groundings : str or tuple of str, optional
+            restrict upward inference to a specific grounding or row in the truth table
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        if kwds.get("lifted"):
+            operands_world = torch.stack(
+                [torch.tensor(op.world) for op in self.operands], dim=-1
+            )
+            result = self.neuron.aggregate_world(
+                tuple(
+                    torch.stack(
+                        [
+                            operands_world[..., 0, :].max(),
+                            operands_world[..., 1, :].min(),
+                        ],
+                        dim=-1,
+                    ).tolist()
+                )
+            )
+            if result:
+                logging.info(
+                    "↑ WORLD FREE-VARIABLE UPDATED "
+                    f"TIGHTENED:{result} "
+                    f"FOR:'{self.name}' "
+                    f"FORMULA:{self.formula_number} "
+                )
+        else:
+            upward_bounds = _gm.upward_bounds(self, self.operands, groundings)
+            if upward_bounds is None:  # contradiction arresting
+                return
+            input_bounds, groundings = upward_bounds
+            grounding_rows = (
+                None
+                if self.propositional
+                else (
+                    self.grounding_table.values()
+                    if groundings is None
+                    else [self.grounding_table.get(g) for g in groundings]
+                )
+            )
+            # Congruent operands must agree: aggregate the greatest lower bound
+            # and the least upper bound across the operands.
+            input_bounds = torch.stack(
+                [
+                    input_bounds[..., 0, :].max(-1)[0],
+                    input_bounds[..., 1, :].min(-1)[0],
+                ],
+                dim=-1,
+            )
+            result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
+            if result:
+                logging.info(
+                    "↑ BOUNDS UPDATED "
+                    f"TIGHTENED:{result} "
+                    f"FOR:'{self.name}' "
+                    f"FORMULA:{self.formula_number} "
+                )
+        return result
+
+    def downward(
+        self,
+        index: int = None,
+        groundings: Set[Union[str, Tuple[str, ...]]] = None,
+        **kwds,
+    ) -> Union[torch.Tensor, None]:
+        r"""Downward inference from the operator to the operands.
+
+        Parameters
+        ----------
+        index : int, optional
+            restricts downward inference to an operand at the specified index. If unspecified, all operands are updated.
+        groundings : str or tuple of str, optional
+            restrict downward inference to a specific grounding or row in the truth table
+        lifted : bool, optional
+            flag that determines if lifting should be done on this node.
+
+        Returns
+        -------
+        tightened_bounds : float
+            The amount of bounds tightening or new information that is learned by the inference step.
+
+        """
+        if kwds.get("lifted"):
+            result = 0
+            for op_idx, op in enumerate(self.operands):
+                op_aggregate = op.neuron.aggregate_world(self.world)
+                result += op_aggregate
+                if op_aggregate:
+                    logging.info(
+                        "↓ WORLD FREE-VARIABLE UPDATED "
+                        f"TIGHTENED:{op_aggregate} "
+                        f"FOR:'{op.name}' "
+                        f"FROM:'{self.name}' "
+                        f"FORMULA:{op.formula_number} "
+                        f"PARENT:{self.formula_number} "
+                    )
+        else:
+            downward_bounds = _gm.downward_bounds(self, self.operands, groundings)
+            if downward_bounds is None:  # contradiction arresting
+                return
+            parent, _, groundings = downward_bounds
+            op_indices = (
+                enumerate(self.operands)
+                if index is None
+                else ([(index, self.operands[index])])
+            )
+            result = 0
+            for op_index, op in op_indices:
+                if op.propositional:
+                    op_grounding_rows = None
+                else:
+                    if groundings is None:
+                        op_grounding_rows = op.grounding_table.values()
+                    else:
+                        op_grounding_rows = [None] * len(groundings)
+                        for g_i, g in enumerate(groundings):
+                            op_g = [
+                                str(g.partial_grounding[slot])
+                                for slot in self.operand_map[op_index]
+                            ]
+                            op_g = _Grounding(tuple(op_g) if len(op_g) > 1 else op_g[0])
+                            op_grounding_rows[g_i] = op.grounding_table.get(op_g)
+                op_aggregate = op.neuron.aggregate_bounds(op_grounding_rows, parent)
+                if op_aggregate:
+                    logging.info(
+                        "↓ BOUNDS UPDATED "
+                        f"TIGHTENED:{op_aggregate} "
+                        f"FOR:'{op.name}' "
+                        f"FROM:'{self.name}' "
+                        f"FORMULA:{op.formula_number} "
+                        f"PARENT:{self.formula_number} "
+                    )
+                result = result + op_aggregate
+        return result
+
+    def extract_congruency(self, *formulae):
+        for formula in formulae:
+            if self not in formula.congruent_nodes:
+                formula.congruent_nodes.append(self)
+
+    def set_congruency(self):
+        for formula in self.operands:
+            if self not in formula.congruent_nodes:
+                formula.congruent_nodes.append(self)
+
+
+class Function:
+    r"""Creates functions in first-order logic formulae
+
+    Parameters
+    ----------
+    name : str
+        name of the function
+    input_dim : int, optional
+        number of input arguments the function accepts, defaults to 1
+
+    Examples
+    --------
+    ```python
+    x, y, z = Variables('x', 'y', 'z')
+    plus_func = Function('plus', input_dim=2)
+    ```
+    """
+
+    # Add output arity
+    def __init__(self, name: str = "", input_dim: int = 1):
+        self.name = name
+        # The arity is expected to be known at construction.
+        self.input_dim = input_dim
+
+        # Constants and functions seen so far for each dimension.
+        self.groundings = dict()
+
+    def __str__(self):
+        args = ""
+        for arg_pos in range(self.input_dim):
+            args += "dim_" + str(arg_pos) + ", "
+        return self.name + "(" + args[0:-2] + ")"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __call__(
+        self,
+        *args: Union[
+            _Grounding,
+            List[_Grounding],
+            str,
+            List[str],
+            "Variable",
+            List["Variable"],
+            List[Union[str, _Grounding, "Variable"]],
+        ],
+    ) -> Union[_Grounding, "Function"]:
+        r"""Calls a function with arguments.
+
+        Parameters
+        ----------
+        args : List of _Grounding and/or output of called Function
+
+        Examples
+        --------
+        ```python
+        y = plus_func(zero, one)
+        ```
+        """
+        # If no input provided, it must map all available groundings.
+        if len(args) != self.input_dim and len(args) != 0:
+            raise Exception(
+                f"expected {self.input_dim} arguments, received {len(args)}"
+            )
+
+        # Guard with `args` so that a no-argument call falls through to the
+        # symbolic case below instead of indexing an empty grounding.
+        if args and all(isinstance(g, (str, _Grounding)) for g in args):
+            # Full grounding
+            ground_str = ""
+            grounding = []
+            for arg_pos, arg in enumerate(args):
+                ground_str += str(arg) + ", "
+                grounding.append(str(arg))
+
+            if len(grounding) > 1:
+                grounding = tuple(grounding)
+            else:
+                grounding = (grounding[0],)
+            ground_out = self.groundings.get(grounding)
+            if ground_out is None:
+                ground_out = _Grounding(self.name + "(" + ground_str[0:-2] + ")")
+                self.groundings[grounding] = ground_out
+
+            return str(ground_out)
+
+        if all(isinstance(g, Variable) for g in args):
+            # All variables
+            return self
+
+        # If not all groundings or variables we have a binding.
+        bindings = {}
+        for arg_pos, arg in enumerate(args):
+            if isinstance(arg, (tuple, Function)):
+                bindings[arg_pos] = arg
+            elif isinstance(arg, (str, _Grounding)):
+                bindings[arg_pos] = [arg]
+            else:
+                if not isinstance(arg, Variable):
+                    raise TypeError(
+                        f"Expected str, _Grounding, Variable, "
+                        f"tuple or Function. Got {type(arg)}"
+                    )
+        return self, bindings
+
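+# The three call forms implemented in `Function.__call__` above:
+#
+#     x, y = Variables("x", "y")
+#     plus = Function("plus", input_dim=2)
+#     plus("zero", "one")  # full grounding -> the str "plus(zero, one)"
+#     plus(x, y)           # all variables  -> the Function itself
+#     plus("zero", y)      # mixed          -> (plus, {0: ["zero"]}) binding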
+
+def Functions(*functions: str, **kwds):
+    r"""Instantiates multiple functions.
+
+    Examples
+    --------
+    ```python
+    f1, f2 = Functions("f1", "f2", input_dim=2)
+    ```
+
+    """
+    return utils.return1([Function(f, **kwds) for f in functions])
+