Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow matching on multiple patterns, check target nodes in relationships. [translator] #819

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,49 @@ void innerJoinWrongColumn() throws SQLException, IOException {

}

@Test // GH-814
void multipleHops() throws SQLException {
try (var connection = getConnection(true, false)) {
var createGraph = """
/*+ NEO4J FORCE_CYPHER */
CREATE (p:Person {name: 'A'})
CREATE (pr:Project {serial: '1'})
CREATE (c:Company {name: 'B'})
CREATE (p)-[:PARTICIPATES_IN]->(pr)
CREATE (p)-[:WORKS_FOR]->(c);
""";

var sql = """
SELECT a.*,c.*,e.*
FROM Person a
JOIN Person_PARTICIPATES_IN_Project b
ON b."v$person_id" = a."v$id"
JOIN Project c
ON c."v$id" = b."v$project_id"
JOIN Person_WORKS_FOR_Company d
ON d."v$person_id"= a."v$id"
JOIN Company e
ON e."v$id" = d."v$company_id"
""";

try (var statement = connection.createStatement()) {
statement.executeUpdate(createGraph);
}

var cypher = connection.nativeSQL(sql);
assertThat(cypher).isEqualTo(
"MATCH (a:Person)-[b:PARTICIPATES_IN]->(c:Project), (a)-[d:WORKS_FOR]->(e:Company) RETURN elementId(a) AS `v$id`, a.name AS name, elementId(c) AS `v$id1`, c.serial AS serial, elementId(e) AS `v$id2`, e.name AS name1");
try (var statement = connection.createStatement(); var result = statement.executeQuery(sql)) {
assertThat(result.next()).isTrue();
assertThat(result.getString("name")).isEqualTo("A");
assertThat(result.getString("name1")).isEqualTo("B");
assertThat(result.next()).isFalse();
}

}

}

record PersonAndTitle(String name, String title) {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private Statement statement(QOM.Delete<?> d) {
this.tables.clear();
this.tables.add(d.$from());

Node e = (Node) resolveTableOrJoin(this.tables.get(0));
Node e = (Node) resolveTableOrJoinf(this.tables.get(0)).get(0);
OngoingReadingWithoutWhere m1 = Cypher.match(e);
OngoingReadingWithWhere m2 = (d.$where() != null) ? m1.where(condition(d.$where()))
: (OngoingReadingWithWhere) m1;
Expand All @@ -321,7 +321,7 @@ private Statement statement(QOM.Truncate<?> t) {
this.tables.clear();
this.tables.addAll(t.$table());

Node e = (Node) resolveTableOrJoin(this.tables.get(0));
Node e = (Node) resolveTableOrJoinf(this.tables.get(0)).get(0);
return Cypher.match(e).detachDelete(e.asExpression()).build();
}

Expand Down Expand Up @@ -350,7 +350,8 @@ private ResultStatement statement(Select<?> incoming) {
return Cypher.returning(resultColumnsSupplier.get()).build();
}

OngoingReadingWithoutWhere m1 = Cypher.match(x.$from().stream().map(this::resolveTableOrJoin).toList());
OngoingReadingWithoutWhere m1 = Cypher
.match(x.$from().stream().flatMap(t -> resolveTableOrJoinf(t).stream()).toList());
OngoingReadingWithWhere m2 = (x.$where() != null) ? m1.where(condition(x.$where()))
: (OngoingReadingWithWhere) m1;

Expand All @@ -377,7 +378,7 @@ private Statement statement(QOM.Insert<?> insert, List<? extends SelectFieldOrAs
this.tables.clear();
this.tables.add(insert.$into());

var node = (Node) this.resolveTableOrJoin(this.tables.get(0));
var node = (Node) this.resolveTableOrJoinf(this.tables.get(0)).get(0);
var rows = insert.$values();

var hasMergeProperties = !insert.$onConflict().isEmpty();
Expand Down Expand Up @@ -513,7 +514,7 @@ private Statement statement(QOM.Update<?> update) {
this.tables.clear();
this.tables.add(update.$table());

var node = (Node) this.resolveTableOrJoin(this.tables.get(0));
var node = (Node) this.resolveTableOrJoinf(this.tables.get(0)).get(0);
var updates = new ArrayList<Expression>();
update.$set().forEach((c, v) -> {
updates.add(node.property(((Field<?>) c).getName()));
Expand Down Expand Up @@ -555,7 +556,7 @@ else if (t instanceof Asterisk) {
}
return properties.stream();
}
else if (t instanceof QualifiedAsterisk q && resolveTableOrJoin(q.$table()) instanceof Node node) {
else if (t instanceof QualifiedAsterisk q && resolveTableOrJoinf(q.$table()).get(0) instanceof Node node) {

var properties = new ArrayList<Expression>();
for (var table : this.tables) {
Expand Down Expand Up @@ -586,7 +587,7 @@ private List<Expression> projectAllColumns(List<Table<?>> from) {
return properties;
}
for (Table<?> table : from) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
var tableName = labelOrType(table);
if (!(pc instanceof Relationship rel)) {
properties.addAll(findProperties(tableName, pc));
Expand Down Expand Up @@ -696,7 +697,7 @@ private SortItem expression(SortField<?> s) {
private Expression findTableFieldInTables(TableField<?, ?> tf, boolean fallbackToFieldName) {
Expression col = null;
if (this.tables.size() == 1) {
var propertyContainer = (PropertyContainer) resolveTableOrJoin(this.tables.get(0));
var propertyContainer = (PropertyContainer) resolveTableOrJoinf(this.tables.get(0)).get(0);
if (isElementId(tf)) {
return makeId(propertyContainer, tf.getName());
}
Expand All @@ -721,19 +722,19 @@ else if (this.databaseMetaData != null) {

// Figure out virtual columns
if (isId) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
return makeId(pc, tf.getName());
}
if (tableName.equalsIgnoreCase(prefix)) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
return makeId(pc, tf.getName());
}

try (var columns = this.databaseMetaData.getColumns(null, null, tableName, null)) {
while (columns.next()) {
var columnName = columns.getString("COLUMN_NAME");
if (columnName.equals(tf.getName())) {
var pc = (PropertyContainer) resolveTableOrJoin(table);
var pc = (PropertyContainer) resolveTableOrJoinf(table).get(0);
col = pc.property(tf.getName());
}
}
Expand Down Expand Up @@ -783,7 +784,7 @@ else if (f instanceof TableField<?, ?> tf) {
return tableField;
}

var pe = resolveTableOrJoin(tf.getTable());
var pe = resolveTableOrJoinf(tf.getTable()).get(0);
if (pe instanceof PropertyContainer pc) {
var m = ELEMENT_ID_PATTERN.matcher(tf.getName());
if (m.matches()) {
Expand Down Expand Up @@ -1225,27 +1226,28 @@ private Condition rowCondition(Row r1, Row r2,
return result;
}

private PatternElement resolveTableOrJoin(Table<?> t) {
private List<PatternElement> resolveTableOrJoinf(Table<?> t) {
var relationship = this.resolvedRelationships.get(Cypher.name(t.getName()));
if (relationship != null) {
return relationship;
return List.of(relationship);
}

if (t instanceof QOM.JoinTable<?, ? extends Table<?>> joinTable) {
return resolveJoin(joinTable);
}

if (t instanceof TableAlias<?> ta) {
var resolved = resolveTableOrJoin(ta.$aliased());
var patternElements = resolveTableOrJoinf(ta.$aliased());
var resolved = (patternElements.size() == 1) ? patternElements.get(0) : null;
if ((resolved instanceof Node || resolved instanceof Relationship) && !ta.$alias().empty()) {
return nodeOrPattern(ta.$aliased(), ta.$alias().last());
return List.of(nodeOrPattern(ta.$aliased(), ta.$alias().last()));
}
else {
throw unsupported(ta);
}
}
else {
return nodeOrPattern(t, t.getName());
return List.of(nodeOrPattern(t, t.getName()));
}
}

Expand Down Expand Up @@ -1277,7 +1279,7 @@ private PatternElement nodeOrPattern(Table<?> t, String name) {
return Cypher.node(primaryLabel).named(symbolicName);
}

private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joinTable) {
private List<PatternElement> resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joinTable) {
var join = JoinDetails.of(joinTable);

String relType = null;
Expand All @@ -1287,9 +1289,9 @@ private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joi

Table<?> t1 = joinTable.$table1();
if (t1 instanceof QOM.JoinTable<?, ? extends Table<?>> lhsJoin) {
lhs = resolveTableOrJoin(lhsJoin.$table1());
lhs = resolveTableOrJoinf(lhsJoin.$table1()).get(0);
var eqJoin2 = JoinDetails.of(lhsJoin);
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, resolveTableOrJoin(lhsJoin.$table2()),
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, resolveTableOrJoinf(lhsJoin.$table2()).get(0),
eqJoin2.eq);
if (relationship != null) {
lhs = relationship;
Expand All @@ -1302,27 +1304,50 @@ private RelationshipPattern resolveJoin(QOM.JoinTable<?, ? extends Table<?>> joi
}
}
else if (join.eq != null) {
lhs = resolveTableOrJoin(t1);
lhs = resolveTableOrJoinf(t1).get(0);
relType = type(t1, join.eq.$arg2());
}
else if (join.join != null && join.join.$using().isEmpty()) {
throw unsupported(joinTable);
}
else {
lhs = resolveTableOrJoin(t1);
lhs = resolveTableOrJoinf(t1).get(0);
relType = (join.join != null) ? type(t1, join.join.$using().get(0)) : null;
}

if (relSymbolicName == null && relType != null) {
relSymbolicName = symbolicName(relType);
}

PatternElement rhs = resolveTableOrJoin(joinTable.$table2());
PatternElement rhs = resolveTableOrJoinf(joinTable.$table2()).get(0);

if (lhs instanceof ExposesRelationships<?> from && rhs instanceof Node to) {
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, rhs, join.eq);
if (relationship != null) {
return relationship;
return List.of(relationship);
}

List<PatternElement> resolved = new ArrayList<>();

// Figure out the left most driving table of the join and check if it's in
// the relationship
Table<?> leftMost = joinTable.$table1();
while (leftMost instanceof QOM.JoinTable<?, ?> tab) {
leftMost = tab.$table1();
}
var hlp = resolveTableOrJoinf(leftMost).get(0);
// We have one single previous relationship, the left most node matching
// the leftmost table and the previous table is a join table.
// Future safety check might actually be comparing the equals operators,
// too.
if (from instanceof Relationship r && hlp instanceof Node leftMostNode
&& r.getLeft().getRequiredSymbolicName().equals(leftMostNode.getRequiredSymbolicName())
&& joinTable.$table1() instanceof QOM.JoinTable<?, ?> previousJoinTable
&& nodeOrPattern(previousJoinTable.$table2(),
"ignored") instanceof Relationship targetRelationship) {
resolved.add(lhs);
from = leftMostNode;
relType = targetRelationship.getDetails().getTypes().get(0);
}

var direction = Relationship.Direction.LTR;
Expand All @@ -1341,12 +1366,13 @@ else if (relationship instanceof RelationshipChain r) {
}
this.resolvedRelationships.put(relSymbolicName, relationship);
}
return relationship;
resolved.add(relationship);
return resolved;
}
else {
var relationship = tryToIntegrateNodeAndVirtualTable(lhs, rhs, join.eq);
if (relationship != null) {
return relationship;
return List.of(relationship);
}
}

Expand Down
Loading