From 60a4bf397aea60a87d814e5d47100160707cc477 Mon Sep 17 00:00:00 2001 From: Shawn Chang Date: Sun, 18 Feb 2024 21:58:18 +0800 Subject: [PATCH] try to save core graph state --- .../src/components/SaveLoadPanel.test.tsx | 6 ++++ .../src/core/ConstantNode.ts | 7 +++++ .../src/core/CoreNode.ts | 3 ++ .../src/core/Graph.ts | 28 ++++++++++++++++++- .../src/core/OperationNode.ts | 9 ++++++ .../src/core/VariableNode.ts | 8 ++++++ .../src/features/CoreGraphAdapter.ts | 2 ++ .../src/states/CoreGraphAdapterState.ts | 5 ++-- .../src/states/CoreGraphState.ts | 5 ++-- .../src/states/CoreNodeState.ts | 20 +++++++++++++ 10 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 interactive-computational-graph/src/states/CoreNodeState.ts diff --git a/interactive-computational-graph/src/components/SaveLoadPanel.test.tsx b/interactive-computational-graph/src/components/SaveLoadPanel.test.tsx index 8172de0..f43e6c0 100644 --- a/interactive-computational-graph/src/components/SaveLoadPanel.test.tsx +++ b/interactive-computational-graph/src/components/SaveLoadPanel.test.tsx @@ -26,6 +26,12 @@ test("should trigger the event when clicking the load button", async () => { const contents = `\ { "coreGraphAdapterState": { + "coreGraphState": { + "nodeIdToNodes": {}, + "differentiationMode": "REVERSE", + "targetNodeId": null, + "nodeIdToDerivatives": {} + }, "nodeIdToNames": {}, "dummyInputNodeIdToNodeIds": {} }, diff --git a/interactive-computational-graph/src/core/ConstantNode.ts b/interactive-computational-graph/src/core/ConstantNode.ts index 6729542..aacc653 100644 --- a/interactive-computational-graph/src/core/ConstantNode.ts +++ b/interactive-computational-graph/src/core/ConstantNode.ts @@ -1,3 +1,4 @@ +import type CoreNodeState from "../states/CoreNodeState"; import type CoreNode from "./CoreNode"; import NodeRelationship from "./NodeRelationship"; import type NodeType from "./NodeType"; @@ -44,6 +45,12 @@ class ConstantNode implements CoreNode { getRelationship(): NodeRelationship { return this.nodeRelationship; } + + save(): CoreNodeState { + return { + nodeType: "CONSTANT", + }; + } } export default ConstantNode; diff --git a/interactive-computational-graph/src/core/CoreNode.ts b/interactive-computational-graph/src/core/CoreNode.ts index f0a495a..b0a0116 100644 --- a/interactive-computational-graph/src/core/CoreNode.ts +++ b/interactive-computational-graph/src/core/CoreNode.ts @@ -1,3 +1,4 @@ +import type CoreNodeState from "../states/CoreNodeState"; import type NodeRelationship from "./NodeRelationship"; import type NodeType from "./NodeType"; @@ -25,6 +26,8 @@ interface CoreNode { calculateDfdx: (x: CoreNode) => string; getRelationship: () => NodeRelationship; + + save: () => CoreNodeState; } export default CoreNode; diff --git a/interactive-computational-graph/src/core/Graph.ts b/interactive-computational-graph/src/core/Graph.ts index 56d8efd..f4cd998 100644 --- a/interactive-computational-graph/src/core/Graph.ts +++ b/interactive-computational-graph/src/core/Graph.ts @@ -1,3 +1,5 @@ +import type CoreGraphState from "../states/CoreGraphState"; +import type CoreNodeState from "../states/CoreNodeState"; import type ChainRuleTerm from "./ChainRuleTerm"; import { CycleError, @@ -27,7 +29,7 @@ class Graph { * the current target node. For forward mode, it propagates from left to * right. For reverse mode, it propagates from right to left. */ - private readonly nodeIdToDerivatives = new Map(); + private nodeIdToDerivatives = new Map(); getNodes(): CoreNode[] { return Array.from(this.nodeIdToNodes.values()); @@ -306,6 +308,30 @@ multiple edges`, }); } + save(): CoreGraphState { + const nodeIdToNodes: Record = {}; + this.nodeIdToNodes.forEach((node, nodeId) => { + nodeIdToNodes[nodeId] = node.save(); + }); + + return { + nodeIdToNodes, + differentiationMode: this.differentiationMode, + targetNodeId: this.targetNodeId, + nodeIdToDerivatives: Object.fromEntries(this.nodeIdToDerivatives), + }; + } + + load(state: CoreGraphState): void { + // TODO(sc420): Build the core nodes + + this.differentiationMode = state.differentiationMode; + this.targetNodeId = state.targetNodeId; + this.nodeIdToDerivatives = new Map( + Object.entries(state.nodeIdToDerivatives), + ); + } + /** * Checks if we can visit the node y from node x in an acyclic graph. * diff --git a/interactive-computational-graph/src/core/OperationNode.ts b/interactive-computational-graph/src/core/OperationNode.ts index a4acf48..b0fcb41 100644 --- a/interactive-computational-graph/src/core/OperationNode.ts +++ b/interactive-computational-graph/src/core/OperationNode.ts @@ -1,3 +1,4 @@ +import type CoreNodeState from "../states/CoreNodeState"; import type CoreNode from "./CoreNode"; import NodeRelationship from "./NodeRelationship"; import type NodeType from "./NodeType"; @@ -90,6 +91,14 @@ class OperationNode implements CoreNode { }); return fInputNodeToValues; } + + save(): CoreNodeState { + return { + nodeType: "OPERATION", + value: this.value, + operationId: "", // TODO(sc420): Return operation ID + }; + } } export default OperationNode; diff --git a/interactive-computational-graph/src/core/VariableNode.ts b/interactive-computational-graph/src/core/VariableNode.ts index 608940a..01afa48 100644 --- a/interactive-computational-graph/src/core/VariableNode.ts +++ b/interactive-computational-graph/src/core/VariableNode.ts @@ -1,3 +1,4 @@ +import type CoreNodeState from "../states/CoreNodeState"; import type CoreNode from "./CoreNode"; import NodeRelationship from "./NodeRelationship"; import type NodeType from "./NodeType"; @@ -50,6 +51,13 @@ class VariableNode implements CoreNode { getRelationship(): NodeRelationship { return this.nodeRelationship; } + + save(): CoreNodeState { + return { + nodeType: "VARIABLE", + value: this.value, + }; + } } export default VariableNode; diff --git a/interactive-computational-graph/src/features/CoreGraphAdapter.ts b/interactive-computational-graph/src/features/CoreGraphAdapter.ts index 6496952..8207a77 100644 --- a/interactive-computational-graph/src/features/CoreGraphAdapter.ts +++ b/interactive-computational-graph/src/features/CoreGraphAdapter.ts @@ -548,6 +548,7 @@ cycle`; save(): CoreGraphAdapterState { return { + coreGraphState: this.graph.save(), nodeIdToNames: Object.fromEntries(this.nodeIdToNames), dummyInputNodeIdToNodeIds: Object.fromEntries( this.dummyInputNodeIdToNodeIds, @@ -556,6 +557,7 @@ cycle`; } load(state: CoreGraphAdapterState): void { + this.graph.load(state.coreGraphState); this.nodeIdToNames = new Map(Object.entries(state.nodeIdToNames)); this.dummyInputNodeIdToNodeIds = new Map( Object.entries(state.dummyInputNodeIdToNodeIds), diff --git a/interactive-computational-graph/src/states/CoreGraphAdapterState.ts b/interactive-computational-graph/src/states/CoreGraphAdapterState.ts index 027adaf..af53e9a 100644 --- a/interactive-computational-graph/src/states/CoreGraphAdapterState.ts +++ b/interactive-computational-graph/src/states/CoreGraphAdapterState.ts @@ -1,8 +1,7 @@ -// import type CoreGraphState from "./CoreGraphState"; +import type CoreGraphState from "./CoreGraphState"; interface CoreGraphAdapterState { - // TODO(sc420): Uncomment - // coreGraphState: CoreGraphState; + coreGraphState: CoreGraphState; nodeIdToNames: Record; dummyInputNodeIdToNodeIds: Record; } diff --git a/interactive-computational-graph/src/states/CoreGraphState.ts b/interactive-computational-graph/src/states/CoreGraphState.ts index b4e0421..f679985 100644 --- a/interactive-computational-graph/src/states/CoreGraphState.ts +++ b/interactive-computational-graph/src/states/CoreGraphState.ts @@ -1,10 +1,11 @@ import type DifferentiationMode from "../core/DifferentiationMode"; +import type CoreNodeState from "./CoreNodeState"; interface CoreGraphState { - // TODO(sc420): Add nodeIdToNodes state + nodeIdToNodes: Record; differentiationMode: DifferentiationMode; targetNodeId: string | null; - nodeIdToDerivatives: Map; + nodeIdToDerivatives: Record; } export default CoreGraphState; diff --git a/interactive-computational-graph/src/states/CoreNodeState.ts b/interactive-computational-graph/src/states/CoreNodeState.ts new file mode 100644 index 0000000..a883526 --- /dev/null +++ b/interactive-computational-graph/src/states/CoreNodeState.ts @@ -0,0 +1,20 @@ +// TODO(sc420): Add nodeRelationship states + +interface ConstantNodeState { + nodeType: "CONSTANT"; +} + +interface VariableNodeState { + nodeType: "VARIABLE"; + value: string; +} + +interface OperationNodeState { + nodeType: "OPERATION"; + value: string; + operationId: string; +} + +type CoreNodeState = ConstantNodeState | VariableNodeState | OperationNodeState; + +export default CoreNodeState;