Skip to content

Commit

Permalink
Inject a GetOrCreateMixin to Node and Relationship. (memgraph#244)
Browse files Browse the repository at this point in the history
* Align develop and main branch (memgraph#237)

Co-authored-by: Bruno Sačarić <[email protected]>

* Inject a GetOrCreateMixin to Node and Relationship.

* Use the imported GQLAlchemyError instead of a fully qualified path.

* Use the imported Tuple instead of fully qualified typing.Tuple

* Run black.

* suppress flake8's f821 on get_or_create definition

* Add tests for get_or_create and refactor the mixin into separate implementations for improved docs.

Signed-off-by: Aalekh Patel <[email protected]>

* Apply black.

Signed-off-by: Aalekh Patel <[email protected]>

* Apply black to tests as well.

Signed-off-by: Aalekh Patel <[email protected]>

* typo fix in function signature.

Signed-off-by: Aalekh Patel <[email protected]>

* Assert the counts of nodes and relationships in the test.

Signed-off-by: Aalekh Patel <[email protected]>

* Apply black.

Signed-off-by: Aalekh Patel <[email protected]>

* Add the .execute() to the query builder in failing tests and provide "name" to the node instantiation because it is a required field.

* another attempt to fix the query builder usage in tests.

* apply black fixes.

* Fix tests by relying on the database identifier `_id` instead of the user-defined `id`.

Signed-off-by: Aalekh Patel <[email protected]>

* Remove unused test functions.

Signed-off-by: Aalekh Patel <[email protected]>

---------

Signed-off-by: Aalekh Patel <[email protected]>
Co-authored-by: Katarina Supe <[email protected]>
Co-authored-by: Bruno Sačarić <[email protected]>
Co-authored-by: katarinasupe <[email protected]>
Co-authored-by: Aalekh Patel <[email protected]>
  • Loading branch information
5 people authored Sep 20, 2023
1 parent 066690f commit 0fe2368
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gqlalchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,22 @@ def load(self, db: "Database") -> "Node": # noqa F821
self._id = node._id
return self

def get_or_create(self, db: "Database") -> Tuple["Node", bool]: # noqa F821
"""Return the node and a flag for whether it was created in the database.
Args:
db: The database instance to operate on.
Returns:
A tuple with the first component being the created graph node,
and the second being a boolean that is True if the node
was created in the database, and False if it was loaded instead.
"""
try:
return self.load(db=db), False
except GQLAlchemyError:
return self.save(db=db), True


class RelationshipMetaclass(BaseModel.__class__):
def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
Expand Down Expand Up @@ -693,6 +709,22 @@ def load(self, db: "Database") -> "Relationship": # noqa F821
self._id = relationship._id
return self

def get_or_create(self, db: "Database") -> Tuple["Relationship", bool]: # noqa F821
"""Return the relationship and a flag for whether it was created in the database.
Args:
db: The database instance to operate on.
Returns:
A tuple with the first component being the created graph relationship,
and the second being a boolean that is True if the relationship
was created in the database, and False if it was loaded instead.
"""
try:
return self.load(db=db), False
except GQLAlchemyError:
return self.save(db=db), True


class Path(GraphObject):
_nodes: Iterable[Node] = PrivateAttr()
Expand Down
88 changes: 88 additions & 0 deletions tests/ogm/test_get_or_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2016-2022 Memgraph Ltd. [https://memgraph.com]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from gqlalchemy import Node, Field, Relationship, GQLAlchemyError


@pytest.mark.parametrize("database", ["neo4j", "memgraph"], indirect=True)
def test_get_or_create_node(database):
class User(Node):
name: str = Field(unique=True, db=database)

class Streamer(User):
name: str = Field(unique=True, db=database)
id: str = Field(index=True, db=database)
followers: int = Field()
totalViewCount: int = Field()

# Assert that loading a node that doesn't yet exist raises GQLAlchemyError.
non_existent_streamer = Streamer(name="Mislav", id="7", followers=777, totalViewCount=7777)
with pytest.raises(GQLAlchemyError):
database.load_node(non_existent_streamer)

streamer, created = non_existent_streamer.get_or_create(database)
assert created is True, "Node.get_or_create should create this node since it doesn't yet exist."
assert streamer.name == "Mislav"
assert streamer.id == "7"
assert streamer.followers == 777
assert streamer.totalViewCount == 7777
assert streamer._labels == {"Streamer", "User"}

assert streamer._id is not None, "Since the streamer was created, it should not have a None _id."

streamer_other, created = non_existent_streamer.get_or_create(database)
assert created is False, "Node.get_or_create should not create this node but load it instead."
assert streamer_other.name == "Mislav"
assert streamer_other.id == "7"
assert streamer_other.followers == 777
assert streamer_other.totalViewCount == 7777
assert streamer_other._labels == {"Streamer", "User"}

assert (
streamer_other._id == streamer._id
), "Since the other streamer wasn't created, it should have the same underlying _id property."


@pytest.mark.parametrize("database", ["neo4j", "memgraph"], indirect=True)
def test_get_or_create_relationship(database):
class User(Node):
name: str = Field(unique=True, db=database)

class Follows(Relationship):
_type = "FOLLOWS"

node_from, created = User(name="foo").get_or_create(database)
assert created is True
assert node_from.name == "foo"

node_to, created = User(name="bar").get_or_create(database)
assert created is True
assert node_to.name == "bar"

assert node_from._id != node_to._id, "Since a new node was created, it should have a different id."

# Assert that loading a relationship that doesn't yet exist raises GQLAlchemyError.
non_existent_relationship = Follows(_start_node_id=node_from._id, _end_node_id=node_to._id)
with pytest.raises(GQLAlchemyError):
database.load_relationship(non_existent_relationship)

relationship, created = non_existent_relationship.get_or_create(database)
assert created is True, "Relationship.get_or_create should create this relationship since it doesn't yet exist."
assert relationship._id is not None
created_id = relationship._id

relationship_loaded, created = non_existent_relationship.get_or_create(database)
assert created is False, "Relationship.get_or_create should not create this relationship but load it instead."
assert relationship_loaded._id == created_id

0 comments on commit 0fe2368

Please sign in to comment.