Skip to content

Commit

Permalink
save/load core graph adapter state
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Feb 15, 2024
1 parent f2c8dfc commit 9604b86
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 16 deletions.
16 changes: 11 additions & 5 deletions interactive-computational-graph/src/components/GraphContainer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,22 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
);

const handleSave = useCallback((): GraphContainerState => {
const coreGraphAdapterState = coreGraphAdapter.save();
return {
coreGraphAdapterState,
isReverseMode,
derivativeTarget,
};
}, [derivativeTarget, isReverseMode]);
}, [coreGraphAdapter, derivativeTarget, isReverseMode]);

const handleLoad = useCallback((graphContainerState: GraphContainerState) => {
setReverseMode(graphContainerState.isReverseMode);
setDerivativeTarget(graphContainerState.derivativeTarget);
}, []);
const handleLoad = useCallback(
(graphContainerState: GraphContainerState) => {
coreGraphAdapter.load(graphContainerState.coreGraphAdapterState);
setReverseMode(graphContainerState.isReverseMode);
setDerivativeTarget(graphContainerState.derivativeTarget);
},
[coreGraphAdapter],
);

const handleReverseModeChange = useCallback(
(isReverseMode: boolean) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ test("should trigger the event when clicking the load button", async () => {

const contents = `\
{
"isReverseMode": true,
"derivativeTarget": null
"coreGraphAdapterState": {
"nodeIdToNames": {},
"dummyInputNodeIdToNodeIds": {}
},
"isReverseMode": true,
"derivativeTarget": null
}
`;
const file = new File([contents], "graph.json", { type: "text/plain" });
Expand Down
23 changes: 20 additions & 3 deletions interactive-computational-graph/src/features/CoreGraphAdapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import type DifferentiationMode from "../core/DifferentiationMode";
import Graph from "../core/Graph";
import OperationNode from "../core/OperationNode";
import VariableNode from "../core/VariableNode";
import type CoreGraphAdapterState from "../states/CoreGraphAdapterState";
import type ExplainDerivativeBuildOptions from "./ExplainDerivativeBuildOptions";
import { buildExplainDerivativeItems } from "./ExplainDerivativeController";
import type ExplainDerivativeData from "./ExplainDerivativeData";
import type ExplainDerivativeType from "./ExplainDerivativeType";
import type FeatureNodeType from "./FeatureNodeType";
import type FeatureOperation from "./FeatureOperation";
import type ExplainDerivativeBuildOptions from "./ExplainDerivativeBuildOptions";

type ConnectionAddedCallback = (connection: Connection) => void;

Expand Down Expand Up @@ -51,8 +52,8 @@ type ExplainDerivativeDataUpdatedCallback = (
class CoreGraphAdapter {
private readonly graph = new Graph();

private readonly nodeIdToNames = new Map<string, string>();
private readonly dummyInputNodeIdToNodeIds = new Map<string, string>();
private nodeIdToNames = new Map<string, string>();
private dummyInputNodeIdToNodeIds = new Map<string, string>();
private selectedNodeIds: string[] = [];

private connectionAddedCallbacks: ConnectionAddedCallback[] = [];
Expand Down Expand Up @@ -545,6 +546,22 @@ cycle`;
return this.graph.getNodeValue(nodeId);
}

save(): CoreGraphAdapterState {
return {
nodeIdToNames: Object.fromEntries(this.nodeIdToNames),
dummyInputNodeIdToNodeIds: Object.fromEntries(
this.dummyInputNodeIdToNodeIds,
),
};
}

load(state: CoreGraphAdapterState): void {
this.nodeIdToNames = new Map(Object.entries(state.nodeIdToNames));
this.dummyInputNodeIdToNodeIds = new Map(
Object.entries(state.dummyInputNodeIdToNodeIds),
);
}

private connectDummyInputNode(nodeId: string, portId: string): void {
const dummyInputNodeId = this.getDummyInputNodeId(nodeId, portId);
this.graph.connect(dummyInputNodeId, nodeId, portId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type CoreGraphState from "./CoreGraphState";
// import type CoreGraphState from "./CoreGraphState";

interface CoreGraphAdapterState {
coreGraphState: CoreGraphState;
nodeIdToNames: Map<string, string>;
dummyInputNodeIdToNodeIds: Map<string, string>;
// TODO(sc420): Uncomment
// coreGraphState: CoreGraphState;
nodeIdToNames: Record<string, string>;
dummyInputNodeIdToNodeIds: Record<string, string>;
}

export default CoreGraphAdapterState;
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// import type FeatureOperation from "../features/FeatureOperation";
// import type CoreGraphAdapterState from "./CoreGraphAdapterState";
import type CoreGraphAdapterState from "./CoreGraphAdapterState";
// import type NodeNameBuilderState from "./NodeNameBuilderState";

interface GraphContainerState {
// TODO(sc420): Uncomment other states
// Core graph
// coreGraphAdapterState: CoreGraphAdapterState;
coreGraphAdapterState: CoreGraphAdapterState;

// Graph state
isReverseMode: boolean;
Expand Down

0 comments on commit 9604b86

Please sign in to comment.