Skip to content

Commit

Permalink
Handle overlapping edges better in edge overlay.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686520719
  • Loading branch information
Google AI Edge authored and copybara-github committed Oct 16, 2024
1 parent f76c30a commit bad7362
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/ui/src/components/visualizer/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -983,12 +983,19 @@ export function getMultiLineLabelExtraHeight(label: string): number {
* Calculates the closest intersection points of a line (L) connecting
* the centers of two rectangles (rect1 and rect2) with the sides of these
* rectangles.
*
* xOffsetFactor is used to shift the center of the rectangle to the left or
* right by a certain factor of the width of the rectangle.
*/
export function getIntersectionPoints(rect1: Rect, rect2: Rect) {
export function getIntersectionPoints(
rect1: Rect,
rect2: Rect,
xOffsetFactor = 0,
) {
// Function to calculate the center of a rectangle
function getCenter(rect: Rect) {
return {
x: rect.x + rect.width / 2,
x: rect.x + rect.width / 2 + xOffsetFactor * rect.width,
y: rect.y + rect.height / 2,
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ export class WebglRendererEdgeOverlaysService {
return;
}

// Keep track of number of edges for a given pair of nodes. If there are
// more than 1 edges, we will shift the edges to avoid overlapping.
//
// From sorted edge key (nodeId1->nodeId2) to the number of edges for that
// pair.
const seenEdgePairs: Record<string, number> = {};
const totalEdgePairs: Record<string, number> = {};

// Populate totalEdgePairs.
for (let i = 0; i < this.curOverlays.length; i++) {
const subgraph = this.curOverlays[i];
for (const {sourceNodeId, targetNodeId, label} of subgraph.edges) {
this.addToEdgePairs(sourceNodeId, targetNodeId, totalEdgePairs);
}
}
console.log(totalEdgePairs);

for (let i = 0; i < this.curOverlays.length; i++) {
const subgraph = this.curOverlays[i];
const edgeWidth = subgraph.edgeWidth ?? DEFAULT_EDGE_WIDTH;
Expand All @@ -98,9 +115,18 @@ export class WebglRendererEdgeOverlaysService {
const targetNode = this.webglRenderer.curModelGraph.nodesById[
targetNodeId
] as OpNode;
const curEdgesCount = this.addToEdgePairs(
sourceNodeId,
targetNodeId,
seenEdgePairs,
);
const totalEdgesCount =
totalEdgePairs[this.getEdgeKey(sourceNodeId, targetNodeId)];
const xOffsetFactor = (1 / (totalEdgesCount + 1)) * curEdgesCount - 0.5;
const {intersection1, intersection2} = getIntersectionPoints(
this.webglRenderer.getNodeRect(sourceNode),
this.webglRenderer.getNodeRect(targetNode),
xOffsetFactor,
);
// Edge.
edges.push({
Expand Down Expand Up @@ -190,4 +216,23 @@ export class WebglRendererEdgeOverlaysService {
}
return [...ids];
}

private addToEdgePairs(
nodeId1: string,
nodeId2: string,
pairs: Record<string, number>,
): number {
const key = this.getEdgeKey(nodeId1, nodeId2);
if (pairs[key] === undefined) {
pairs[key] = 0;
}
pairs[key]++;
return pairs[key];
}

private getEdgeKey(nodeId1: string, nodeId2: string): string {
return nodeId1.localeCompare(nodeId2) < 0
? `${nodeId1}___${nodeId2}`
: `${nodeId2}___${nodeId1}`;
}
}

0 comments on commit bad7362

Please sign in to comment.