diff --git a/interactive-computational-graph/src/core/ConstantNode.test.ts b/interactive-computational-graph/src/core/ConstantNode.test.ts index 3ce44aa..ad0f3fb 100644 --- a/interactive-computational-graph/src/core/ConstantNode.test.ts +++ b/interactive-computational-graph/src/core/ConstantNode.test.ts @@ -36,3 +36,15 @@ test("can get relationship", () => { const constNode = new ConstantNode("c1"); expect(constNode.getRelationship()).toBeInstanceOf(NodeRelationship); }); + +test("can save the state", () => { + const constNode = new ConstantNode("c1"); + constNode.setValue("3"); + const state = constNode.save(); + expect(state).toEqual( + expect.objectContaining({ + nodeType: "CONSTANT", + value: "3", + }), + ); +}); diff --git a/interactive-computational-graph/src/core/Graph.test.ts b/interactive-computational-graph/src/core/Graph.test.ts index 3ab2e8b..f319314 100644 --- a/interactive-computational-graph/src/core/Graph.test.ts +++ b/interactive-computational-graph/src/core/Graph.test.ts @@ -968,6 +968,69 @@ describe("explaining chain rule", () => { }); }); +describe("saving graph state", () => { + test("should save the empty graph", () => { + const graph = new Graph(); + + const state = graph.save(); + expect(state).toEqual({ + nodeIdToNodes: {}, + differentiationMode: "REVERSE", + targetNodeId: null, + }); + }); + + test("should save the small graph", () => { + const graph = buildSmallGraph(); + graph.setDifferentiationMode("FORWARD"); + graph.setTargetNode("sum1"); + + const state = graph.save(); + expect(state).toEqual({ + nodeIdToNodes: { + v1: { + nodeType: "VARIABLE", + value: "2", + relationship: { + inputPortIdToNodeIds: {}, + }, + }, + v2: { + nodeType: "VARIABLE", + value: "1", + relationship: { + inputPortIdToNodeIds: {}, + }, + }, + sum1: { + nodeType: "OPERATION", + operationId: "sum", + relationship: { + inputPortIdToNodeIds: { + x_i: ["v1", "v2"], + }, + }, + }, + }, + differentiationMode: "FORWARD", + targetNodeId: "sum1", + }); + }); +}); + +describe("clearing graph state", () => { + test("should clear the small graph", () => { + const graph = buildSmallGraph(); + graph.updateFValues(); + graph.setDifferentiationMode("REVERSE"); + graph.setTargetNode("v1"); + graph.clear(); + + expect(graph.getNodes()).toHaveLength(0); + expect(parseFloat(graph.getNodeDerivative("v1"))).toBeCloseTo(0); + }); +}); + function buildSmallGraph(): Graph { const graph = new Graph(); diff --git a/interactive-computational-graph/src/core/NodeRelationship.test.ts b/interactive-computational-graph/src/core/NodeRelationship.test.ts index 5543afc..7c4d0af 100644 --- a/interactive-computational-graph/src/core/NodeRelationship.test.ts +++ b/interactive-computational-graph/src/core/NodeRelationship.test.ts @@ -239,14 +239,29 @@ describe("output behavior", () => { expect(nodeRelationship.getOutputNodes()).toEqual([]); }); +}); - function getDummyOperationNode(id: string): OperationNode { - const op = new Operation("", ""); - return new OperationNode(id, [new Port("in1", false)], "dummy", op); - } +describe("state behavior", () => { + test("should save the relationship", () => { + const nodeRelationship = buildTwoPortsNodeRelationship(); + const opNode1 = getDummyOperationNode("op1"); + nodeRelationship.addInputNodeByPort("a", opNode1); + + expect(nodeRelationship.save()).toEqual({ + inputPortIdToNodeIds: { + a: ["op1"], + b: [], + }, + }); + }); }); function buildTwoPortsNodeRelationship(): NodeRelationship { const ports = [new Port("a", false), new Port("b", true)]; return new NodeRelationship(ports); } + +function getDummyOperationNode(id: string): OperationNode { + const op = new Operation("", ""); + return new OperationNode(id, [new Port("in1", false)], "dummy", op); +} diff --git a/interactive-computational-graph/src/core/OperationNode.test.ts b/interactive-computational-graph/src/core/OperationNode.test.ts index ae9ff18..ef312d2 100644 --- a/interactive-computational-graph/src/core/OperationNode.test.ts +++ b/interactive-computational-graph/src/core/OperationNode.test.ts @@ -57,6 +57,17 @@ test("can get relationship", () => { expect(sumNode.getRelationship()).toBeInstanceOf(NodeRelationship); }); +test("can save the state", () => { + const sumNode = buildSumNode(); + const state = sumNode.save(); + expect(state).toEqual( + expect.objectContaining({ + nodeType: "OPERATION", + operationId: "sum", + }), + ); +}); + function buildSumNode(): OperationNode { const ports: Port[] = [new Port("x_i", true)]; const operation = new Operation(SUM_F_CODE, SUM_DFDX_CODE); diff --git a/interactive-computational-graph/src/core/VariableNode.test.ts b/interactive-computational-graph/src/core/VariableNode.test.ts index c409a5f..12e9944 100644 --- a/interactive-computational-graph/src/core/VariableNode.test.ts +++ b/interactive-computational-graph/src/core/VariableNode.test.ts @@ -28,3 +28,15 @@ test("can get relationship", () => { const varNode = new VariableNode("v1"); expect(varNode.getRelationship()).toBeInstanceOf(NodeRelationship); }); + +test("can save the state", () => { + const varNode = new VariableNode("v1"); + varNode.setValue("3"); + const state = varNode.save(); + expect(state).toEqual( + expect.objectContaining({ + nodeType: "VARIABLE", + value: "3", + }), + ); +}); diff --git a/interactive-computational-graph/src/features/CoreGraphAdapter.test.ts b/interactive-computational-graph/src/features/CoreGraphAdapter.test.ts index b22f4b8..623b07a 100644 --- a/interactive-computational-graph/src/features/CoreGraphAdapter.test.ts +++ b/interactive-computational-graph/src/features/CoreGraphAdapter.test.ts @@ -562,6 +562,44 @@ describe("behavior", () => { adapter.getNodeNameById("1"); }).toThrow(); }); + + test("should save the state", () => { + const adapter = new CoreGraphAdapter(); + + addConstantNode(adapter, "1", "c_1"); + addAddNode(adapter, "2", "a_1"); + addConnection(adapter, "1", "2", "a"); + + const state = adapter.save(); + expect(state).toEqual( + expect.objectContaining({ + nodeIdToNames: { + "1": "c_1", + "2": "a_1", + "dummy-input-node-2-a": "a_1.a", + "dummy-input-node-2-b": "a_1.b", + }, + dummyInputNodeIdToNodeIds: { + "dummy-input-node-2-a": "2", + "dummy-input-node-2-b": "2", + }, + }), + ); + }); + + test("should load from the saved state correctly", () => { + const adapter = new CoreGraphAdapter(); + + addConstantNode(adapter, "1", "c_1"); + addAddNode(adapter, "2", "a_1"); + addConnection(adapter, "1", "2", "a"); + + const state = adapter.save(); + adapter.load(state, [featureOperation]); + + expect(adapter.getNodeNameById("1")).toBe("c_1"); + expect(adapter.getNodeNameById("2")).toBe("a_1"); + }); }); const featureOperation: FeatureOperation = { diff --git a/interactive-computational-graph/src/features/NodeNameBuilder.test.ts b/interactive-computational-graph/src/features/NodeNameBuilder.test.ts index 63d7d90..f571e0e 100644 --- a/interactive-computational-graph/src/features/NodeNameBuilder.test.ts +++ b/interactive-computational-graph/src/features/NodeNameBuilder.test.ts @@ -1,5 +1,6 @@ import Operation from "../core/Operation"; import Port from "../core/Port"; +import type NodeNameBuilderState from "../states/NodeNameBuilderState"; import { ADD_DFDX_CODE, ADD_F_CODE } from "./BuiltInCode"; import type FeatureNodeType from "./FeatureNodeType"; import type FeatureOperation from "./FeatureOperation"; @@ -20,6 +21,39 @@ test("should build names with interleaving node types", () => { expect(builder.buildName(getVariableNodeType(), null)).toBe("v_4"); }); +test("should save the state", () => { + const builder = new NodeNameBuilder(); + + builder.buildName(getConstantNodeType(), null); + builder.buildName(getAddNodeType(), getAddOperation()); + builder.buildName(getVariableNodeType(), null); + + expect(builder.save()).toEqual({ + constantCounter: 2, + variableCounter: 2, + operationIdToCounter: { + add: 2, + }, + }); +}); + +test("should have correct state after loading", () => { + const builder = new NodeNameBuilder(); + const state: NodeNameBuilderState = { + constantCounter: 2, + variableCounter: 2, + operationIdToCounter: { + add: 2, + }, + }; + + builder.load(state); + + expect(builder.buildName(getConstantNodeType(), null)).toBe("c_2"); + expect(builder.buildName(getAddNodeType(), getAddOperation())).toBe("a_2"); + expect(builder.buildName(getVariableNodeType(), null)).toBe("v_2"); +}); + const getConstantNodeType = (): FeatureNodeType => { return { nodeType: "CONSTANT",