Skip to content

Commit

Permalink
try to save core graph state
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Feb 18, 2024
1 parent 9604b86 commit 60a4bf3
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ test("should trigger the event when clicking the load button", async () => {
const contents = `\
{
"coreGraphAdapterState": {
"coreGraphState": {
"nodeIdToNodes": {},
"differentiationMode": "REVERSE",
"targetNodeId": null,
"nodeIdToDerivatives": {}
},
"nodeIdToNames": {},
"dummyInputNodeIdToNodeIds": {}
},
Expand Down
7 changes: 7 additions & 0 deletions interactive-computational-graph/src/core/ConstantNode.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type CoreNodeState from "../states/CoreNodeState";
import type CoreNode from "./CoreNode";
import NodeRelationship from "./NodeRelationship";
import type NodeType from "./NodeType";
Expand Down Expand Up @@ -44,6 +45,12 @@ class ConstantNode implements CoreNode {
getRelationship(): NodeRelationship {
return this.nodeRelationship;
}

save(): CoreNodeState {
return {
nodeType: "CONSTANT",
};
}
}

export default ConstantNode;
3 changes: 3 additions & 0 deletions interactive-computational-graph/src/core/CoreNode.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type CoreNodeState from "../states/CoreNodeState";
import type NodeRelationship from "./NodeRelationship";
import type NodeType from "./NodeType";

Expand Down Expand Up @@ -25,6 +26,8 @@ interface CoreNode {
calculateDfdx: (x: CoreNode) => string;

getRelationship: () => NodeRelationship;

save: () => CoreNodeState;
}

export default CoreNode;
28 changes: 27 additions & 1 deletion interactive-computational-graph/src/core/Graph.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import type CoreGraphState from "../states/CoreGraphState";
import type CoreNodeState from "../states/CoreNodeState";
import type ChainRuleTerm from "./ChainRuleTerm";
import {
CycleError,
Expand Down Expand Up @@ -27,7 +29,7 @@ class Graph {
* the current target node. For forward mode, it propagates from left to
* right. For reverse mode, it propagates from right to left.
*/
private readonly nodeIdToDerivatives = new Map<string, string>();
private nodeIdToDerivatives = new Map<string, string>();

getNodes(): CoreNode[] {
return Array.from(this.nodeIdToNodes.values());
Expand Down Expand Up @@ -306,6 +308,30 @@ multiple edges`,
});
}

save(): CoreGraphState {
const nodeIdToNodes: Record<string, CoreNodeState> = {};
this.nodeIdToNodes.forEach((node, nodeId) => {
nodeIdToNodes[nodeId] = node.save();
});

return {
nodeIdToNodes,
differentiationMode: this.differentiationMode,
targetNodeId: this.targetNodeId,
nodeIdToDerivatives: Object.fromEntries(this.nodeIdToDerivatives),
};
}

load(state: CoreGraphState): void {
// TODO(sc420): Build the core nodes

this.differentiationMode = state.differentiationMode;
this.targetNodeId = state.targetNodeId;
this.nodeIdToDerivatives = new Map(
Object.entries(state.nodeIdToDerivatives),
);
}

/**
* Checks if we can visit the node y from node x in an acyclic graph.
*
Expand Down
9 changes: 9 additions & 0 deletions interactive-computational-graph/src/core/OperationNode.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type CoreNodeState from "../states/CoreNodeState";
import type CoreNode from "./CoreNode";
import NodeRelationship from "./NodeRelationship";
import type NodeType from "./NodeType";
Expand Down Expand Up @@ -90,6 +91,14 @@ class OperationNode implements CoreNode {
});
return fInputNodeToValues;
}

save(): CoreNodeState {
return {
nodeType: "OPERATION",
value: this.value,
operationId: "", // TODO(sc420): Return operation ID
};
}
}

export default OperationNode;
8 changes: 8 additions & 0 deletions interactive-computational-graph/src/core/VariableNode.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type CoreNodeState from "../states/CoreNodeState";
import type CoreNode from "./CoreNode";
import NodeRelationship from "./NodeRelationship";
import type NodeType from "./NodeType";
Expand Down Expand Up @@ -50,6 +51,13 @@ class VariableNode implements CoreNode {
getRelationship(): NodeRelationship {
return this.nodeRelationship;
}

save(): CoreNodeState {
return {
nodeType: "VARIABLE",
value: this.value,
};
}
}

export default VariableNode;
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ cycle`;

save(): CoreGraphAdapterState {
return {
coreGraphState: this.graph.save(),
nodeIdToNames: Object.fromEntries(this.nodeIdToNames),
dummyInputNodeIdToNodeIds: Object.fromEntries(
this.dummyInputNodeIdToNodeIds,
Expand All @@ -556,6 +557,7 @@ cycle`;
}

load(state: CoreGraphAdapterState): void {
this.graph.load(state.coreGraphState);
this.nodeIdToNames = new Map(Object.entries(state.nodeIdToNames));
this.dummyInputNodeIdToNodeIds = new Map(
Object.entries(state.dummyInputNodeIdToNodeIds),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// import type CoreGraphState from "./CoreGraphState";
import type CoreGraphState from "./CoreGraphState";

interface CoreGraphAdapterState {
// TODO(sc420): Uncomment
// coreGraphState: CoreGraphState;
coreGraphState: CoreGraphState;
nodeIdToNames: Record<string, string>;
dummyInputNodeIdToNodeIds: Record<string, string>;
}
Expand Down
5 changes: 3 additions & 2 deletions interactive-computational-graph/src/states/CoreGraphState.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import type DifferentiationMode from "../core/DifferentiationMode";
import type CoreNodeState from "./CoreNodeState";

interface CoreGraphState {
// TODO(sc420): Add nodeIdToNodes state
nodeIdToNodes: Record<string, CoreNodeState>;
differentiationMode: DifferentiationMode;
targetNodeId: string | null;
nodeIdToDerivatives: Map<string, string>;
nodeIdToDerivatives: Record<string, string>;
}

export default CoreGraphState;
20 changes: 20 additions & 0 deletions interactive-computational-graph/src/states/CoreNodeState.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// TODO(sc420): Add nodeRelationship states

interface ConstantNodeState {
nodeType: "CONSTANT";
}

interface VariableNodeState {
nodeType: "VARIABLE";
value: string;
}

interface OperationNodeState {
nodeType: "OPERATION";
value: string;
operationId: string;
}

type CoreNodeState = ConstantNodeState | VariableNodeState | OperationNodeState;

export default CoreNodeState;

0 comments on commit 60a4bf3

Please sign in to comment.