Skip to content

Commit

Permalink
test save/load APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Mar 3, 2024
1 parent e5f2eba commit 2eb8486
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 4 deletions.
12 changes: 12 additions & 0 deletions interactive-computational-graph/src/core/ConstantNode.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,15 @@ test("can get relationship", () => {
const constNode = new ConstantNode("c1");
expect(constNode.getRelationship()).toBeInstanceOf(NodeRelationship);
});

test("can save the state", () => {
const constNode = new ConstantNode("c1");
constNode.setValue("3");
const state = constNode.save();
expect(state).toEqual(
expect.objectContaining({
nodeType: "CONSTANT",
value: "3",
}),
);
});
63 changes: 63 additions & 0 deletions interactive-computational-graph/src/core/Graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,69 @@ describe("explaining chain rule", () => {
});
});

describe("saving graph state", () => {
test("should save the empty graph", () => {
const graph = new Graph();

const state = graph.save();
expect(state).toEqual({
nodeIdToNodes: {},
differentiationMode: "REVERSE",
targetNodeId: null,
});
});

test("should save the small graph", () => {
const graph = buildSmallGraph();
graph.setDifferentiationMode("FORWARD");
graph.setTargetNode("sum1");

const state = graph.save();
expect(state).toEqual({
nodeIdToNodes: {
v1: {
nodeType: "VARIABLE",
value: "2",
relationship: {
inputPortIdToNodeIds: {},
},
},
v2: {
nodeType: "VARIABLE",
value: "1",
relationship: {
inputPortIdToNodeIds: {},
},
},
sum1: {
nodeType: "OPERATION",
operationId: "sum",
relationship: {
inputPortIdToNodeIds: {
x_i: ["v1", "v2"],
},
},
},
},
differentiationMode: "FORWARD",
targetNodeId: "sum1",
});
});
});

describe("clearing graph state", () => {
test("should clear the small graph", () => {
const graph = buildSmallGraph();
graph.updateFValues();
graph.setDifferentiationMode("REVERSE");
graph.setTargetNode("v1");
graph.clear();

expect(graph.getNodes()).toHaveLength(0);
expect(parseFloat(graph.getNodeDerivative("v1"))).toBeCloseTo(0);
});
});

function buildSmallGraph(): Graph {
const graph = new Graph();

Expand Down
23 changes: 19 additions & 4 deletions interactive-computational-graph/src/core/NodeRelationship.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,29 @@ describe("output behavior", () => {

expect(nodeRelationship.getOutputNodes()).toEqual([]);
});
});

function getDummyOperationNode(id: string): OperationNode {
const op = new Operation("", "");
return new OperationNode(id, [new Port("in1", false)], "dummy", op);
}
describe("state behavior", () => {
test("should save the relationship", () => {
const nodeRelationship = buildTwoPortsNodeRelationship();
const opNode1 = getDummyOperationNode("op1");
nodeRelationship.addInputNodeByPort("a", opNode1);

expect(nodeRelationship.save()).toEqual({
inputPortIdToNodeIds: {
a: ["op1"],
b: [],
},
});
});
});

function buildTwoPortsNodeRelationship(): NodeRelationship {
const ports = [new Port("a", false), new Port("b", true)];
return new NodeRelationship(ports);
}

function getDummyOperationNode(id: string): OperationNode {
const op = new Operation("", "");
return new OperationNode(id, [new Port("in1", false)], "dummy", op);
}
11 changes: 11 additions & 0 deletions interactive-computational-graph/src/core/OperationNode.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ test("can get relationship", () => {
expect(sumNode.getRelationship()).toBeInstanceOf(NodeRelationship);
});

test("can save the state", () => {
const sumNode = buildSumNode();
const state = sumNode.save();
expect(state).toEqual(
expect.objectContaining({
nodeType: "OPERATION",
operationId: "sum",
}),
);
});

function buildSumNode(): OperationNode {
const ports: Port[] = [new Port("x_i", true)];
const operation = new Operation(SUM_F_CODE, SUM_DFDX_CODE);
Expand Down
12 changes: 12 additions & 0 deletions interactive-computational-graph/src/core/VariableNode.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ test("can get relationship", () => {
const varNode = new VariableNode("v1");
expect(varNode.getRelationship()).toBeInstanceOf(NodeRelationship);
});

test("can save the state", () => {
const varNode = new VariableNode("v1");
varNode.setValue("3");
const state = varNode.save();
expect(state).toEqual(
expect.objectContaining({
nodeType: "VARIABLE",
value: "3",
}),
);
});
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,44 @@ describe("behavior", () => {
adapter.getNodeNameById("1");
}).toThrow();
});

test("should save the state", () => {
const adapter = new CoreGraphAdapter();

addConstantNode(adapter, "1", "c_1");
addAddNode(adapter, "2", "a_1");
addConnection(adapter, "1", "2", "a");

const state = adapter.save();
expect(state).toEqual(
expect.objectContaining({
nodeIdToNames: {
"1": "c_1",
"2": "a_1",
"dummy-input-node-2-a": "a_1.a",
"dummy-input-node-2-b": "a_1.b",
},
dummyInputNodeIdToNodeIds: {
"dummy-input-node-2-a": "2",
"dummy-input-node-2-b": "2",
},
}),
);
});

test("should load from the saved state correctly", () => {
const adapter = new CoreGraphAdapter();

addConstantNode(adapter, "1", "c_1");
addAddNode(adapter, "2", "a_1");
addConnection(adapter, "1", "2", "a");

const state = adapter.save();
adapter.load(state, [featureOperation]);

expect(adapter.getNodeNameById("1")).toBe("c_1");
expect(adapter.getNodeNameById("2")).toBe("a_1");
});
});

const featureOperation: FeatureOperation = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Operation from "../core/Operation";
import Port from "../core/Port";
import type NodeNameBuilderState from "../states/NodeNameBuilderState";
import { ADD_DFDX_CODE, ADD_F_CODE } from "./BuiltInCode";
import type FeatureNodeType from "./FeatureNodeType";
import type FeatureOperation from "./FeatureOperation";
Expand All @@ -20,6 +21,39 @@ test("should build names with interleaving node types", () => {
expect(builder.buildName(getVariableNodeType(), null)).toBe("v_4");
});

test("should save the state", () => {
const builder = new NodeNameBuilder();

builder.buildName(getConstantNodeType(), null);
builder.buildName(getAddNodeType(), getAddOperation());
builder.buildName(getVariableNodeType(), null);

expect(builder.save()).toEqual({
constantCounter: 2,
variableCounter: 2,
operationIdToCounter: {
add: 2,
},
});
});

test("should have correct state after loading", () => {
const builder = new NodeNameBuilder();
const state: NodeNameBuilderState = {
constantCounter: 2,
variableCounter: 2,
operationIdToCounter: {
add: 2,
},
};

builder.load(state);

expect(builder.buildName(getConstantNodeType(), null)).toBe("c_2");
expect(builder.buildName(getAddNodeType(), getAddOperation())).toBe("a_2");
expect(builder.buildName(getVariableNodeType(), null)).toBe("v_2");
});

const getConstantNodeType = (): FeatureNodeType => {
return {
nodeType: "CONSTANT",
Expand Down

0 comments on commit 2eb8486

Please sign in to comment.