Skip to content

Commit

Permalink
fix: Allow matching on multiple patterns, check target nodes in relat…
Browse files Browse the repository at this point in the history
…ionships. [translator]

This change allows first and foremost that the translator can now not only match on chains of relationships, but also on multiple, indirectly connected patterns.

This sets the groundwork to be able to create an equivalent match for a chain of join statements that are not really a chain, but starts with a driving table that joins to several other tables.

Fixes #814.

Signed-off-by: Michael Simons <[email protected]>
  • Loading branch information
michael-simons committed Jan 7, 2025
1 parent 1c1562b commit 498a74f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,49 @@ void innerJoinWrongColumn() throws SQLException, IOException {

}

@Test // GH-814
void multipleHops() throws SQLException {
try (var connection = getConnection(true, false)) {
var createGraph = """
/*+ NEO4J FORCE_CYPHER */
CREATE (p:Person {name: 'A'})
CREATE (pr:Project {serial: '1'})
CREATE (c:Company {name: 'B'})
CREATE (p)-[:PARTICIPATES_IN]->(pr)
CREATE (p)-[:WORKS_FOR]->(c);
""";

var sql = """
SELECT a.*,c.*,e.*
FROM Person a
JOIN Person_PARTICIPATES_IN_Project b
ON b."v$person_id" = a."v$id"
JOIN Project c
ON c."v$id" = b."v$project_id"
JOIN Person_WORKS_FOR_Company d
ON d."v$person_id"= a."v$id"
JOIN Company e
ON e."v$id" = d."v$company_id"
""";

try (var statement = connection.createStatement()) {
statement.executeUpdate(createGraph);
}

var cypher = connection.nativeSQL(sql);
assertThat(cypher).isEqualTo(
"MATCH (a:Person)-[b:PARTICIPATES_IN]->(c:Project), (a)-[d:WORKS_FOR]->(e:Company) RETURN elementId(a) AS `v$id`, a.name AS name, elementId(c) AS `v$id1`, c.serial AS serial, elementId(e) AS `v$id2`, e.name AS name1");
try (var statement = connection.createStatement(); var result = statement.executeQuery(sql)) {
assertThat(result.next()).isTrue();
assertThat(result.getString("name")).isEqualTo("A");
assertThat(result.getString("name1")).isEqualTo("B");
assertThat(result.next()).isFalse();
}

}

}

record PersonAndTitle(String name, String title) {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private Statement statement(QOM.Delete<?> d) {
this.tables.clear();
this.tables.add(d.$from());

Node e = (Node) resolveTableOrJoin(this.tables.get(0));
Node e = (Node) resolveTableOrJoinf(this.tables.get(0)).get(0);
OngoingReadingWithoutWhere m1 = Cypher.match(e);
OngoingReadingWithWhere m2 = (d.$where() != null) ? m1.where(condition(d.$where()))
: (OngoingReadingWithWhere) m1;
Expand All @@ -321,7 +321,7 @@ private Statement statement(QOM.Truncate<?> t) {
this.tables.clear();
this.tables.addAll(t.$table());

Node e = (Node) resolveTableOrJoin(this.tables.get(0));
Node e = (Node) resolveTableOrJoinf(this.tables.get(0)).get(0);
return Cypher.match(e).detachDelete(e.asExpression()).build();
}

Expand Down Expand Up @@ -350,7 +350,8 @@ private ResultStatement statement(Select<?> incoming) {
return Cypher.returning(resultColumnsSupplier.get()).build();
}

OngoingReadingWithoutWhere m1 = Cypher.match(x.$from().stream().map(this::resolveTableOrJoin).toList());
OngoingReadingWithoutWhere m1 = Cypher
.match(x.$from().stream().flatMap(t -> resolveTableOrJoinf(t).stream()).toList());
OngoingReadingWithWhere m2 = (x.$where() != null) ? m1.where(condition(x.$where()))
: (OngoingReadingWithWhere) m1;

Expand All @@ -377,7 +378,7 @@ private Statement statement(QOM.Insert<?> insert, List<? extends SelectFieldOrAs
this.tables.clear();
this.tables.add(insert.$into());

var node = (Node) this.resolveTableOrJoin(this.tables.get(0));
var node = (Node) this.resolveTableOrJoinf(this.tables.get(0)).get(0);
var rows = insert.$values();

var hasMergeProperties = !insert.$onConflict().isEmpty();
Expand Down Expand Up @@ -513,7 +514,7 @@ private Statement statement(QOM.Update<?> update) {
this.tables.clear();
this.tables.add(update.$table());

var node = (Node) this.resolveTableOrJoin(this.tables.get(0));
var node = (Node) this.resolveTableOrJoinf(this.tables.get(0)).get(0);
var updates = new ArrayList<Expression>();
update.$set().forEach((c, v) -> {
updates.add(node.property(((Field<?>) c).getName()));
Expand Down Expand Up @@ -555,7 +556,7 @@ else if (t instanceof Asterisk) {
}
return properties.stream();
}
else if (t instanceof QualifiedAsterisk q && resolveTableOrJoin(q.$table()) instanceof Node node) {
else if (t instanceof QualifiedAsterisk q && resolveTableOrJoinf(q.$table()).get(0) instanceof Node node) {

var properties = new ArrayList<Expression>();
for (var table : this.tables) {
Expand Down Expand Up @@ -586,7 +587,7 @@ private List<Expression> projectAllColumns(List<Table<?>> from) {
return properties;
}
for (Table<?> table : from) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
var tableName = labelOrType(table);
if (!(pc instanceof Relationship rel)) {
properties.addAll(findProperties(tableName, pc));
Expand Down Expand Up @@ -696,7 +697,7 @@ private SortItem expression(SortField<?> s) {
private Expression findTableFieldInTables(TableField<?, ?> tf, boolean fallbackToFieldName) {
Expression col = null;
if (this.tables.size() == 1) {
var propertyContainer = (PropertyContainer) resolveTableOrJoin(this.tables.get(0));
var propertyContainer = (PropertyContainer) resolveTableOrJoinf(this.tables.get(0)).get(0);
if (isElementId(tf)) {
return makeId(propertyContainer, tf.getName());
}
Expand All @@ -721,19 +722,19 @@ else if (this.databaseMetaData != null) {

// Figure out virtual columns
if (isId) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
return makeId(pc, tf.getName());
}
if (tableName.equalsIgnoreCase(prefix)) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
return makeId(pc, tf.getName());
}

try (var columns = this.databaseMetaData.getColumns(null, null, tableName, null)) {
while (columns.next()) {
var columnName = columns.getString("COLUMN_NAME");
if (columnName.equals(tf.getName())) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
col = pc.property(tf.getName());
}
}
Expand Down Expand Up @@ -783,7 +784,7 @@ else if (f instanceof TableField<?, ?> tf) {
return tableField;
}

var pe = resolveTableOrJoin(tf.getTable());
var pe = resolveTableOrJoinf(tf.getTable()).get(0);
if (pe instanceof PropertyContainer pc) {
var m = ELEMENT_ID_PATTERN.matcher(tf.getName());
if (m.matches()) {
Expand Down Expand Up @@ -1225,27 +1226,28 @@ private Condition rowCondition(Row r1, Row r2,
return result;
}

private PatternElement resolveTableOrJoin(Table<?> t) {
private List<PatternElement> resolveTableOrJoinf(Table<?> t) {
var relationship = this.resolvedRelationships.get(Cypher.name(t.getName()));
if (relationship != null) {
return relationship;
return List.of(relationship);
}

if (t instanceof QOM.JoinTable<?, ? extends Table<?>> joinTable) {
return resolveJoin(joinTable);
}

if (t instanceof TableAlias<?> ta) {
var resolved = resolveTableOrJoin(ta.$aliased());
var patternElements = resolveTableOrJoinf(ta.$aliased());
var resolved = (patternElements.size() == 1) ? patternElements.get(0) : null;
if ((resolved instanceof Node || resolved instanceof Relationship) && !ta.$alias().empty()) {
return nodeOrPattern(ta.$aliased(), ta.$alias().last());
return List.of(nodeOrPattern(ta.$aliased(), ta.$alias().last()));
}
else {
throw unsupported(ta);
}
}
else {
return nodeOrPattern(t, t.getName());
return List.of(nodeOrPattern(t, t.getName()));
}
}

Expand Down Expand Up @@ -1277,7 +1279,7 @@ private PatternElement nodeOrPattern(Table<?> t, String name) {
return Cypher.node(primaryLabel).named(symbolicName);
}

private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joinTable) {
private List<PatternElement> resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joinTable) {
var join = JoinDetails.of(joinTable);

String relType = null;
Expand All @@ -1287,9 +1289,9 @@ private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joi

Table<?> t1 = joinTable.$table1();
if (t1 instanceof QOM.JoinTable<?, ? extends Table<?>> lhsJoin) {
lhs = resolveTableOrJoin(lhsJoin.$table1());
lhs = resolveTableOrJoinf(lhsJoin.$table1()).get(0);
var eqJoin2 = JoinDetails.of(lhsJoin);
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, resolveTableOrJoin(lhsJoin.$table2()),
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, resolveTableOrJoinf(lhsJoin.$table2()).get(0),
eqJoin2.eq);
if (relationship != null) {
lhs = relationship;
Expand All @@ -1302,27 +1304,50 @@ private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joi
}
}
else if (join.eq != null) {
lhs = resolveTableOrJoin(t1);
lhs = resolveTableOrJoinf(t1).get(0);
relType = type(t1, join.eq.$arg2());
}
else if (join.join != null && join.join.$using().isEmpty()) {
throw unsupported(joinTable);
}
else {
lhs = resolveTableOrJoin(t1);
lhs = resolveTableOrJoinf(t1).get(0);
relType = (join.join != null) ? type(t1, join.join.$using().get(0)) : null;
}

if (relSymbolicName == null && relType != null) {
relSymbolicName = symbolicName(relType);
}

PatternElement rhs = resolveTableOrJoin(joinTable.$table2());
PatternElement rhs = resolveTableOrJoinf(joinTable.$table2()).get(0);

if (lhs instanceof ExposesRelationships<?> from && rhs instanceof Node to) {
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, rhs, join.eq);
if (relationship != null) {
return relationship;
return List.of(relationship);
}

List<PatternElement> resolved = new ArrayList<>();

// Figure out the left most driving table of the join and check if it's in
// the relationship
Table<?> leftMost = joinTable.$table1();
while (leftMost instanceof QOM.JoinTable<?, ?> tab) {
leftMost = tab.$table1();
}
var hlp = resolveTableOrJoinf(leftMost).get(0);
// We have one single previous relationship, the left most node matching
// the leftmost table and the previous table is a join table.
// Future safety check might actually be comparing the equals operators,
// too.
if (from instanceof Relationship r && hlp instanceof Node leftMostNode
&& r.getLeft().getRequiredSymbolicName().equals(leftMostNode.getRequiredSymbolicName())
&& joinTable.$table1() instanceof QOM.JoinTable<?, ?> previousJoinTable
&& nodeOrPattern(previousJoinTable.$table2(),
"ignored") instanceof Relationship targetRelationship) {
resolved.add(lhs);
from = leftMostNode;
relType = targetRelationship.getDetails().getTypes().get(0);
}

var direction = Relationship.Direction.LTR;
Expand All @@ -1341,12 +1366,13 @@ else if (relationship instanceof RelationshipChain r) {
}
this.resolvedRelationships.put(relSymbolicName, relationship);
}
return relationship;
resolved.add(relationship);
return resolved;
}
else {
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, rhs, join.eq);
if (relationship != null) {
return relationship;
return List.of(relationship);
}
}

Expand Down

0 comments on commit 498a74f

Please sign in to comment.