diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index 6918b19a..6e8dc1ad 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -40,6 +40,26 @@ __all__ = ("Memgraph",) +class MemgraphTransaction: + def __init__(self, username: str, transaction_id: str, query: list, metadata: dict): + self.username = (username,) + self.transaction_id = transaction_id + self.query = query + self.metadata = metadata + + def __repr__(self): + return f"MemgraphTransaction(username={self.username}, transaction_id={self.transaction_id}, query={self.query}, metadata={self.metadata})" + + +class MemgraphTerminatedTransaction: + def __init__(self, transaction_id: str, killed: bool): + self.transaction_id = transaction_id + self.killed = killed + + def __repr__(self): + return f"MemgraphTerminatedTransaction(transaction_id={self.transaction_id}, killed={self.killed})" + + class MemgraphConstants: CONSTRAINT_TYPE = "constraint type" EXISTS = "exists" @@ -49,6 +69,34 @@ class MemgraphConstants: UNIQUE = "unique" +def create_transaction(transaction_data) -> MemgraphTransaction: + """Create a MemgraphTransaction object from transaction data. + Args: + transaction_data (dict): A dictionary containing transaction data. + Returns: + MemgraphTransaction: A MemgraphTransaction object. + """ + return MemgraphTransaction( + username=transaction_data["username"], + transaction_id=transaction_data["transaction_id"], + query=transaction_data["query"], + metadata=transaction_data["metadata"], + ) + + +def create_terminated_transaction(transaction_data) -> MemgraphTerminatedTransaction: + """Create a MemgraphTerminatedTransaction object from transaction data. + Args: + transaction_data (dict): A dictionary containing transaction data. + Returns: + MemgraphTerminatedTransaction: A MemgraphTerminatedTransaction object. + """ + return MemgraphTerminatedTransaction( + transaction_id=transaction_data["transaction_id"], + killed=transaction_data["killed"], + ) + + class Memgraph(DatabaseClient): def __init__( self, @@ -432,3 +480,32 @@ def with_power_bi(self) -> "Memgraph": module_name = "power_bi_stream.py" return self.add_query_module(file_path=file_path, module_name=module_name) + + def get_transactions(self) -> List[MemgraphTransaction]: + """Get all transactions in the database. + Returns: + List[MemgraphTransaction]: A list of MemgraphTransaction objects. + """ + + transactions_data = self.execute_and_fetch("SHOW TRANSACTIONS;") + transactions = list(map(create_transaction, transactions_data)) + + return transactions + + def terminate_transactions(self, transaction_ids: List[str]) -> List[MemgraphTerminatedTransaction]: + """Terminate transactions in the database. + Args: + transaction_ids (List[str]): A list of transaction ids to terminate. + Returns: + List[MemgraphTerminatedTransaction]: A list of MemgraphTerminatedTransaction objects with info on their status. + """ + + query = ( + "TERMINATE TRANSACTIONS " + ", ".join([f"'{transaction_id}'" for transaction_id in transaction_ids]) + ";" + ) + + transactions_data = self.execute_and_fetch(query) + + terminated_transactions = list(map(create_terminated_transaction, transactions_data)) + + return terminated_transactions diff --git a/tests/ogm/test_transactions.py b/tests/ogm/test_transactions.py new file mode 100644 index 00000000..c814eb10 --- /dev/null +++ b/tests/ogm/test_transactions.py @@ -0,0 +1,14 @@ +def test_get_transactions(memgraph): + result = memgraph.get_transactions() + assert len(result) == 1 + assert result[0].username == ("",) + assert result[0].transaction_id != "" + assert result[0].query == ["SHOW TRANSACTIONS;"] + assert result[0].metadata == {} + + +def test_terminate_transactions(memgraph): + result = memgraph.get_transactions() + terminated_transactions = memgraph.terminate_transactions([result[0].transaction_id]) + assert terminated_transactions[0].killed is False + assert terminated_transactions[0].transaction_id == result[0].transaction_id