From f1dd8d28a197cde3cd06f8c5c38257acad7c3c6f Mon Sep 17 00:00:00 2001 From: Gordon Date: Sun, 3 Mar 2024 17:49:20 -0500 Subject: [PATCH 1/6] Add utility to print KB graphs Also don't apply edge rules if a property sets a value explicitly via constraints --- cmd/kb/dot.go | 138 +++++++++++ cmd/kb/main.go | 229 ++++++++++++++++++ go.mod | 1 + go.sum | 2 + pkg/dot/attributes.go | 8 +- pkg/engine/edge_targets.go | 3 +- pkg/engine/engine.go | 3 +- .../operational_eval/vertex_property.go | 46 +++- .../operational_eval/vertex_property_test.go | 115 ++++++++- pkg/engine/path_selection/path_selection.go | 22 +- pkg/engine/solution_context.go | 3 +- pkg/graph_addons/reverse.go | 12 +- pkg/graph_addons/topology.go | 137 +++++++++++ pkg/infra/cli.go | 3 +- .../iac/templates/aws/ecs_service/factory.ts | 9 +- .../aws/ecs_task_definition/factory.ts | 6 +- .../iac/templates/aws/target_group/factory.ts | 3 +- pkg/knowledgebase/kb.go | 4 + pkg/knowledgebase/resource_template.go | 10 + .../edges/ecs_task_definition-iam_role.yaml | 20 ++ .../aws/edges/target_group-ecs_service.yaml | 2 +- pkg/templates/aws/resources/ecs_service.yaml | 10 +- 22 files changed, 736 insertions(+), 50 deletions(-) create mode 100644 cmd/kb/dot.go create mode 100644 cmd/kb/main.go create mode 100644 pkg/graph_addons/topology.go diff --git a/cmd/kb/dot.go b/cmd/kb/dot.go new file mode 100644 index 000000000..064815f88 --- /dev/null +++ b/cmd/kb/dot.go @@ -0,0 +1,138 @@ +package main + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sort" + + "github.com/dominikbraun/graph" + "github.com/klothoplatform/klotho/pkg/dot" + "github.com/klothoplatform/klotho/pkg/graph_addons" + "github.com/klothoplatform/klotho/pkg/knowledgebase" +) + +func dotAttributes(tmpl *knowledgebase.ResourceTemplate, props graph.VertexProperties) map[string]string { + a := make(map[string]string) + for k, v := range props.Attributes { + if k != "rank" { + a[k] = v + } + } + a["label"] = tmpl.QualifiedTypeName + a["shape"] = "box" + return a +} + +func dotEdgeAttributes(e *knowledgebase.EdgeTemplate, props graph.EdgeProperties) map[string]string { + a := make(map[string]string) + for k, v := range props.Attributes { + a[k] = v + } + if e.DeploymentOrderReversed { + a["style"] = "dashed" + } + a["edgetooltip"] = fmt.Sprintf("%s -> %s", e.Source, e.Target) + return a +} + +func KbToDot(g graph.Graph[string, *knowledgebase.ResourceTemplate], out io.Writer) error { + ids, err := graph_addons.TopologicalSort(g, func(a, b string) bool { + return a < b + }) + if err != nil { + return err + } + var errs error + printf := func(s string, args ...any) { + _, err := fmt.Fprintf(out, s, args...) + errs = errors.Join(errs, err) + } + printf(`digraph { + rankdir = TB +`) + for _, id := range ids { + t, props, err := g.VertexWithProperties(id) + if err != nil { + errs = errors.Join(errs, err) + continue + } + if rank, ok := props.Attributes["rank"]; ok { + printf(" { rank = %s; %q%s; }\n", rank, id, dot.AttributesToString(dotAttributes(t, props))) + } else { + printf(" %q%s;\n", t.QualifiedTypeName, dot.AttributesToString(dotAttributes(t, props))) + } + } + + topoIndex := func(id string) int { + for i, id2 := range ids { + if id2 == id { + return i + } + } + return -1 + } + edges, err := g.Edges() + if err != nil { + return err + } + sort.Slice(edges, func(i, j int) bool { + ti, tj := topoIndex(edges[i].Source), topoIndex(edges[j].Source) + if ti != tj { + return ti < tj + } + ti, tj = topoIndex(edges[i].Target), topoIndex(edges[j].Target) + return ti < tj + }) + for _, e := range edges { + et, ok := e.Properties.Data.(*knowledgebase.EdgeTemplate) + if !ok { + errs = errors.Join(errs, fmt.Errorf("edge %q -> %q has no EdgeTemplate", e.Source, e.Target)) + continue + } + printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(et, e.Properties))) + } + printf("}\n") + return errs +} + +func KbToSVG(kb knowledgebase.TemplateKB, prefix string) error { + if debugDir := os.Getenv("KLOTHO_DEBUG_DIR"); debugDir != "" { + prefix = filepath.Join(debugDir, prefix) + } + f, err := os.Create(prefix + ".gv") + if err != nil { + return err + } + defer f.Close() + + hasGraph, ok := kb.(interface { + Graph() graph.Graph[string, *knowledgebase.ResourceTemplate] + }) + if !ok { + return fmt.Errorf("knowledgebase does not have a graph") + } + g := hasGraph.Graph() + + dotContent := new(bytes.Buffer) + err = KbToDot(g, io.MultiWriter(f, dotContent)) + if err != nil { + return fmt.Errorf("could not render graph to file %s: %v", prefix+".gv", err) + } + + svgContent, err := dot.ExecPan(bytes.NewReader(dotContent.Bytes())) + if err != nil { + return fmt.Errorf("could not run 'dot' for %s: %v", prefix+".gv", err) + } + + svgFile, err := os.Create(prefix + ".gv.svg") + if err != nil { + return fmt.Errorf("could not create file %s: %v", prefix+".gv.svg", err) + } + defer svgFile.Close() + _, err = fmt.Fprint(svgFile, svgContent) + return err +} diff --git a/cmd/kb/main.go b/cmd/kb/main.go new file mode 100644 index 000000000..f844f4f85 --- /dev/null +++ b/cmd/kb/main.go @@ -0,0 +1,229 @@ +package main + +import ( + "errors" + "fmt" + + "github.com/alecthomas/kong" + "github.com/dominikbraun/graph" + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/engine" + "github.com/klothoplatform/klotho/pkg/engine/path_selection" + "github.com/klothoplatform/klotho/pkg/graph_addons" + "github.com/klothoplatform/klotho/pkg/knowledgebase" + "github.com/klothoplatform/klotho/pkg/knowledgebase/reader" + "github.com/klothoplatform/klotho/pkg/logging" + "github.com/klothoplatform/klotho/pkg/templates" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Args struct { + Verbose bool `short:"v" help:"Enable verbose mode"` + Distance int `short:"d" help:"Distance from single type to display" default:"2"` + Classification string `short:"c" help:"Classification to filter for (like path expansion)"` + Source string `arg:"" optional:""` + Target string `arg:"" optional:""` +} + +func main() { + var args Args + ctx := kong.Parse(&args) + + logOpts := logging.LogOpts{ + Verbose: args.Verbose, + CategoryLogsDir: "", + DefaultLevels: map[string]zapcore.Level{ + "lsp": zap.WarnLevel, + "lsp/pylsp": zap.WarnLevel, + }, + Encoding: "pretty_console", + } + + zap.ReplaceGlobals(logOpts.NewLogger()) + defer zap.L().Sync() //nolint:errcheck + + if err := args.Run(ctx); err != nil { + panic(err) + } +} + +func (args Args) Run(ctx *kong.Context) error { + kb, err := reader.NewKBFromFs(templates.ResourceTemplates, templates.EdgeTemplates, templates.Models) + if err != nil { + return err + } + + switch { + case args.Source == "" && args.Target == "": + break + case args.Target == "": + if args.Classification != "" { + return fmt.Errorf("classification can only be used with two types (for now)") + } + kb = args.filterSingleKb(kb) + default: + if args.Classification != "" { + var edge construct.SimpleEdge + if err := edge.Source.UnmarshalText([]byte(args.Source)); err != nil { + return fmt.Errorf("could not parse source: %w", err) + } + edge.Source.Name = "source" + if err := edge.Target.UnmarshalText([]byte(args.Target)); err != nil { + return fmt.Errorf("could not parse target: %w", err) + } + edge.Target.Name = "target" + g, err := path_selection.BuildPathSelectionGraph(edge, kb, args.Classification) + if err != nil { + return err + } + return engine.GraphToSVG(kb, g, "kb_path_selection") + } + kb = args.filterPathKB(kb) + } + + return KbToSVG(kb, "knowledgebase") +} + +func (args Args) filterPathKB(kb *knowledgebase.KnowledgeBase) *knowledgebase.KnowledgeBase { + var source, target construct.ResourceId + if err := source.UnmarshalText([]byte(args.Source)); err != nil { + panic(fmt.Errorf("could not parse source: %w", err)) + } + if err := target.UnmarshalText([]byte(args.Target)); err != nil { + panic(fmt.Errorf("could not parse target: %w", err)) + } + + paths, err := kb.AllPaths(source, target) + if err != nil { + panic(err) + } + shortestPath, err := graph.ShortestPath(kb.Graph(), args.Source, args.Target) + if err != nil { + panic(err) + } + + filteredKb := knowledgebase.NewKB() + g := filteredKb.Graph() + addV := func(t *knowledgebase.ResourceTemplate) (err error) { + if t.QualifiedTypeName == args.Source || t.QualifiedTypeName == args.Target { + attribs := map[string]string{ + "color": "green", + "penwidth": "2", + } + if t.QualifiedTypeName == args.Source { + attribs["rank"] = "source" + } else { + attribs["rank"] = "sink" + } + err = g.AddVertex(t, graph.VertexAttributes(attribs)) + } else { + err = g.AddVertex(t) + } + if errors.Is(err, graph.ErrVertexAlreadyExists) { + return nil + } + return err + } + addE := func(path []*knowledgebase.ResourceTemplate, t1, t2 *knowledgebase.ResourceTemplate) error { + edge, err := kb.Graph().Edge(t1.QualifiedTypeName, t2.QualifiedTypeName) + if err != nil { + return err + } + err = g.AddEdge(t1.QualifiedTypeName, t2.QualifiedTypeName, func(ep *graph.EdgeProperties) { + *ep = edge.Properties + if len(path) == len(shortestPath) { + ep.Attributes["color"] = "green" + ep.Attributes["penwidth"] = "2" + } + }) + if errors.Is(err, graph.ErrEdgeAlreadyExists) { + return nil + } + return err + } + var errs error + for _, path := range paths { + if len(path) > len(shortestPath)*2 { + continue + } + errs = errors.Join(errs, addV(path[0])) + for i, t := range path[1:] { + errs = errors.Join( + errs, + addV(t), + addE(path, path[i], t), + ) + } + } + return filteredKb +} + +func (args Args) filterSingleKb(kb *knowledgebase.KnowledgeBase) *knowledgebase.KnowledgeBase { + filteredKb := knowledgebase.NewKB() + g := filteredKb.Graph() + + r, props, err := kb.Graph().VertexWithProperties(args.Source) + if err != nil { + panic(err) + } + err = g.AddVertex(r, func(vp *graph.VertexProperties) { + *vp = props + vp.Attributes["color"] = "green" + vp.Attributes["penwidth"] = "2" + }) + if err != nil { + panic(err) + } + + addV := func(s string) (err error) { + t, err := kb.Graph().Vertex(s) + if err != nil { + return err + } + err = g.AddVertex(t) + if errors.Is(err, graph.ErrVertexAlreadyExists) { + return nil + } + return err + } + walkFunc := func(up bool) func(p graph_addons.Path[string], nerr error) error { + edge := func(a, b string) (graph.Edge[*knowledgebase.ResourceTemplate], error) { + if up { + a, b = b, a + } + return kb.Graph().Edge(a, b) + } + + return func(p graph_addons.Path[string], nerr error) error { + last := p[len(p)-1] + if err := addV(last); err != nil { + return err + } + edge, err := edge(p[len(p)-2], last) + if err != nil { + return err + } + err = g.AddEdge(p[len(p)-2], last, func(ep *graph.EdgeProperties) { + *ep = edge.Properties + }) + if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) { + return err + } + if len(p) >= args.Distance { + return graph_addons.SkipPath + } + return nil + } + } + + err = errors.Join( + graph_addons.WalkUp(kb.Graph(), args.Source, walkFunc(true)), + graph_addons.WalkDown(kb.Graph(), args.Source, walkFunc(false)), + ) + if err != nil { + panic(err) + } + + return filteredKb +} diff --git a/go.mod b/go.mod index a1a91652a..d1c0dc8b3 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,7 @@ require ( github.com/Code-Hex/dd v1.1.0 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect + github.com/alecthomas/kong v0.8.1 // indirect github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/gojek/valkyrie v0.0.0-20180215180059-6aee720afcdf // indirect github.com/golang/protobuf v1.5.3 // indirect diff --git a/go.sum b/go.sum index b85140b67..a7271d9c7 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYr github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= +github.com/alecthomas/kong v0.8.1 h1:acZdn3m4lLRobeh3Zi2S2EpnXTd1mOL6U7xVml+vfkY= +github.com/alecthomas/kong v0.8.1/go.mod h1:n1iCIO2xS46oE8ZfYCNDqdR0b0wZNrXAIAqro/2132U= github.com/alitto/pond v1.8.3 h1:ydIqygCLVPqIX/USe5EaV/aSRXTRXDEI9JwuDdu+/xs= github.com/alitto/pond v1.8.3/go.mod h1:CmvIIGd5jKLasGI3D87qDkQxjzChdKMmnXMg3fG6M6Q= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= diff --git a/pkg/dot/attributes.go b/pkg/dot/attributes.go index 860730c9c..ce13bfed1 100644 --- a/pkg/dot/attributes.go +++ b/pkg/dot/attributes.go @@ -18,8 +18,12 @@ func AttributesToString(attribs map[string]string) string { var list []string for _, k := range keys { v := attribs[k] - v = strings.ReplaceAll(v, `"`, `\"`) - list = append(list, fmt.Sprintf(`%s="%s"`, k, v)) + if len(v) > 1 && v[0] == '<' && v[len(v)-1] == '>' { + list = append(list, fmt.Sprintf(`%s=%s`, k, v)) + } else { + v = strings.ReplaceAll(v, `"`, `\"`) + list = append(list, fmt.Sprintf(`%s="%s"`, k, v)) + } } return " [" + strings.Join(list, ", ") + "]" } diff --git a/pkg/engine/edge_targets.go b/pkg/engine/edge_targets.go index d2727e8da..aef99019b 100644 --- a/pkg/engine/edge_targets.go +++ b/pkg/engine/edge_targets.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/klothoplatform/klotho/pkg/engine/constraints" "github.com/klothoplatform/klotho/pkg/engine/solution_context" "github.com/alitto/pond" @@ -159,7 +160,7 @@ func (e *Engine) GetValidEdgeTargets(context *GetPossibleEdgesContext) (map[stri if err != nil { return nil, err } - solutionCtx := NewSolutionContext(e.Kb, "") + solutionCtx := NewSolutionContext(e.Kb, "", &constraints.Constraints{}) err = solutionCtx.LoadGraph(inputGraph) if err != nil { return nil, err diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 6d2dfa42d..a56128b1b 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -30,8 +30,7 @@ func NewEngine(kb knowledgebase.TemplateKB) *Engine { } func (e *Engine) Run(context *EngineContext) error { - solutionCtx := NewSolutionContext(e.Kb, context.GlobalTag) - solutionCtx.constraints = &context.Constraints + solutionCtx := NewSolutionContext(e.Kb, context.GlobalTag, &context.Constraints) err := solutionCtx.LoadGraph(context.InitialState) if err != nil { return err diff --git a/pkg/engine/operational_eval/vertex_property.go b/pkg/engine/operational_eval/vertex_property.go index 2ba55b4ba..5ca6e6815 100644 --- a/pkg/engine/operational_eval/vertex_property.go +++ b/pkg/engine/operational_eval/vertex_property.go @@ -52,14 +52,16 @@ func (prop *propertyVertex) Dependencies(eval *Evaluator, propCtx dependencyCapt } } - for edge, rule := range prop.EdgeRules { - edgeData := knowledgebase.DynamicValueData{ - Resource: prop.Ref.Resource, - Edge: &construct.Edge{Source: edge.Source, Target: edge.Target}, - } - for _, opRule := range rule { - if err := propCtx.ExecuteOpRule(edgeData, opRule); err != nil { - return fmt.Errorf("could not execute edge operational rule for %s: %w", prop.Ref, err) + if prop.shouldEvalEdges(eval.Solution.Constraints().Resources) { + for edge, rule := range prop.EdgeRules { + edgeData := knowledgebase.DynamicValueData{ + Resource: prop.Ref.Resource, + Edge: &construct.Edge{Source: edge.Source, Target: edge.Target}, + } + for _, opRule := range rule { + if err := propCtx.ExecuteOpRule(edgeData, opRule); err != nil { + return fmt.Errorf("could not execute edge operational rule for %s: %w", prop.Ref, err) + } } } } @@ -129,12 +131,14 @@ func (v *propertyVertex) Evaluate(eval *Evaluator) error { // // we still need to run the resource operational rules though, // to make sure dependencies exist where properties have operational rules set - if err := v.evaluateResourceOperational(res, &opCtx); err != nil { + if err := v.evaluateResourceOperational(&opCtx); err != nil { return err } - if err := v.evaluateEdgeOperational(res, &opCtx); err != nil { - return err + if v.shouldEvalEdges(eval.Solution.Constraints().Resources) { + if err := v.evaluateEdgeOperational(res, &opCtx); err != nil { + return err + } } if err := eval.UpdateId(v.Ref.Resource, res.ID); err != nil { @@ -259,7 +263,6 @@ func (v *propertyVertex) evaluateConstraints( } func (v *propertyVertex) evaluateResourceOperational( - res *construct.Resource, opCtx operational_rule.OpRuleHandler, ) error { if v.Template == nil || v.Template.Details().OperationalRule == nil { @@ -274,6 +277,25 @@ func (v *propertyVertex) evaluateResourceOperational( return nil } +// shouldEvalEdges is used as common logic for whether edges should be evaluated and is used in dependency +// calculation and in the Evaluate method. +func (v *propertyVertex) shouldEvalEdges(cs []constraints.ResourceConstraint) bool { + if knowledgebase.IsCollectionProperty(v.Template) { + return true + } + for _, c := range cs { + if c.Target != v.Ref.Resource || c.Property != v.Ref.Property { + continue + } + // NOTE(gg): does operator even matter here? If it's not a collection, + // what does an 'add' mean? Should it allow edges to overwrite? + if c.Operator == constraints.EqualsConstraintOperator { + return false + } + } + return true +} + func (v *propertyVertex) evaluateEdgeOperational( res *construct.Resource, opCtx operational_rule.OpRuleHandler, diff --git a/pkg/engine/operational_eval/vertex_property_test.go b/pkg/engine/operational_eval/vertex_property_test.go index 914c6b0e3..ead3169fc 100644 --- a/pkg/engine/operational_eval/vertex_property_test.go +++ b/pkg/engine/operational_eval/vertex_property_test.go @@ -38,8 +38,7 @@ func Test_propertyVertex_evaluateResourceOperational(t *testing.T) { Value: "test", } type args struct { - v *propertyVertex - res *construct.Resource + v *propertyVertex } tests := []struct { name string @@ -60,7 +59,6 @@ func Test_propertyVertex_evaluateResourceOperational(t *testing.T) { }, }, }, - res: &construct.Resource{ID: construct.ResourceId{Name: "test"}}, }, }, } @@ -70,7 +68,7 @@ func Test_propertyVertex_evaluateResourceOperational(t *testing.T) { ctrl := gomock.NewController(t) opctx := NewMockOpRuleHandler(ctrl) opctx.EXPECT().HandlePropertyRule(*rule).Return(nil).Times(1) - err := tt.args.v.evaluateResourceOperational(tt.args.res, opctx) + err := tt.args.v.evaluateResourceOperational(opctx) if tt.wantErr { assert.Error(err) return @@ -81,6 +79,72 @@ func Test_propertyVertex_evaluateResourceOperational(t *testing.T) { } } +func Test_propertyVertex_shouldEvalEdges(t *testing.T) { + ref := construct.PropertyRef{ + Property: "test", + Resource: construct.ResourceId{Name: "test"}, + } + tests := []struct { + name string + v *propertyVertex + constraints []constraints.ResourceConstraint + want bool + }{ + { + name: "no constraints always evals", + v: &propertyVertex{ + Ref: ref, + Template: &properties.StringProperty{}, + }, + want: true, + }, + { + name: "collection always evals", + v: &propertyVertex{ + Ref: ref, + Template: &properties.ListProperty{}, + }, + want: true, + }, + { + name: "no matching constraints evals", + v: &propertyVertex{ + Ref: ref, + Template: &properties.StringProperty{}, + }, + constraints: []constraints.ResourceConstraint{ + { + Operator: constraints.EqualsConstraintOperator, + Target: construct.ResourceId{Name: "not_test"}, + Property: "test", + }, + }, + want: true, + }, + { + name: "matching constraint does not eval", + v: &propertyVertex{ + Ref: ref, + Template: &properties.StringProperty{}, + }, + constraints: []constraints.ResourceConstraint{ + { + Operator: constraints.EqualsConstraintOperator, + Target: construct.ResourceId{Name: "test"}, + Property: "test", + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.v.shouldEvalEdges(tt.constraints) + assert.Equal(t, tt.want, got) + }) + } +} + func Test_propertyVertex_evaluateEdgeOperational(t *testing.T) { rule := knowledgebase.OperationalRule{ If: "test", @@ -139,10 +203,11 @@ func Test_propertyVertex_evaluateEdgeOperational(t *testing.T) { func Test_propertyVertex_Dependencies(t *testing.T) { tests := []struct { - name string - v *propertyVertex - mocks func(dcap *MockdependencyCapturer, resource *construct.Resource, path construct.PropertyPath) - wantErr bool + name string + v *propertyVertex + constraints constraints.Constraints + mocks func(dcap *MockdependencyCapturer, resource *construct.Resource, path construct.PropertyPath) + wantErr bool }{ { name: "property vertex with template", @@ -199,6 +264,39 @@ func Test_propertyVertex_Dependencies(t *testing.T) { }).Return(nil) }, }, + { + name: "property vertex with edge rules not considered due to constraints", + v: &propertyVertex{ + Ref: construct.PropertyRef{ + Property: "test", + Resource: construct.ResourceId{Name: "test"}, + }, + Template: &properties.StringProperty{}, + EdgeRules: map[construct.SimpleEdge][]knowledgebase.OperationalRule{ + { + Source: construct.ResourceId{Name: "test"}, + Target: construct.ResourceId{Name: "test2"}, + }: { + { + If: "testE", + }, + }, + }, + }, + constraints: constraints.Constraints{ + Resources: []constraints.ResourceConstraint{ + { + Operator: constraints.EqualsConstraintOperator, + Target: construct.ResourceId{Name: "test"}, + Property: "test", + }, + }, + }, + mocks: func(dcap *MockdependencyCapturer, resource *construct.Resource, path construct.PropertyPath) { + // expect no calls to ExecuteOpRule due to shouldEvalEdges returning false + dcap.EXPECT().ExecuteOpRule(gomock.Any(), gomock.Any()).Times(0) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -213,6 +311,7 @@ func Test_propertyVertex_Dependencies(t *testing.T) { } tt.mocks(dcap, resource, path) testSol := enginetesting.NewTestSolution() + testSol.Constr = tt.constraints testSol.KB.On("GetResourceTemplate", mock.Anything).Return(&knowledgebase.ResourceTemplate{}, nil) err = testSol.RawView().AddVertex(resource) if !assert.NoError(t, err) { diff --git a/pkg/engine/path_selection/path_selection.go b/pkg/engine/path_selection/path_selection.go index cca9f958b..7a3a64c78 100644 --- a/pkg/engine/path_selection/path_selection.go +++ b/pkg/engine/path_selection/path_selection.go @@ -11,7 +11,7 @@ import ( "go.uber.org/zap" ) -// PHANTOM_PREFIX deliberately uses an invalid character so if it leaks into an actualy input/output, it will +// PHANTOM_PREFIX deliberately uses an invalid character so if it leaks into an actual input/output, it will // fail to parse. const PHANTOM_PREFIX = "phantom$" const GLUE_WEIGHT = 100 @@ -75,6 +75,7 @@ func BuildPathSelectionGraph( if err != nil && !errors.Is(err, graph.ErrVertexAlreadyExists) { return nil, fmt.Errorf("failed to add target vertex to path selection graph for %s: %w", dep, err) } + satisfied_paths := 0 for _, path := range paths { resourcePath := make([]construct.ResourceId, len(path)) for i, res := range path { @@ -83,6 +84,24 @@ func BuildPathSelectionGraph( if !PathSatisfiesClassification(kb, resourcePath, classification) { continue } + // Check to see if the whole path is valid before adding phantoms and edges. + // It's a miniscule performance benefit, and is mostly done for clarity in the debug graph output. + validPath := true + for i, res := range path { + if i == 0 { + continue + } + edgeTemplate := kb.GetEdgeTemplate(path[i-1].Id(), res.Id()) + if edgeTemplate == nil || edgeTemplate.DirectEdgeOnly { + validPath = false + break + } + } + if !validPath { + continue + } + + satisfied_paths++ var prevRes construct.ResourceId for i, res := range path { id, err := makePhantom(tempGraph, res.Id()) @@ -111,6 +130,7 @@ func BuildPathSelectionGraph( prevRes = id } } + zap.S().Debugf("Found %d paths for %s :: %s", satisfied_paths, dep, classification) return tempGraph, nil } diff --git a/pkg/engine/solution_context.go b/pkg/engine/solution_context.go index b2cd5393d..afea4da7a 100644 --- a/pkg/engine/solution_context.go +++ b/pkg/engine/solution_context.go @@ -28,7 +28,7 @@ type ( } ) -func NewSolutionContext(kb knowledgebase.TemplateKB, globalTag string) *solutionContext { +func NewSolutionContext(kb knowledgebase.TemplateKB, globalTag string, constraints *constraints.Constraints) *solutionContext { ctx := &solutionContext{ KB: kb, Dataflow: graph_addons.LoggingGraph[construct.ResourceId, *construct.Resource]{ @@ -39,6 +39,7 @@ func NewSolutionContext(kb knowledgebase.TemplateKB, globalTag string) *solution Deployment: construct.NewAcyclicGraph(), decisions: &solution_context.MemoryRecord{}, mappedResources: make(map[construct.ResourceId]construct.ResourceId), + constraints: constraints, globalTag: globalTag, } ctx.propertyEval = property_eval.NewEvaluator(ctx) diff --git a/pkg/graph_addons/reverse.go b/pkg/graph_addons/reverse.go index 4a58e3766..8bb369e33 100644 --- a/pkg/graph_addons/reverse.go +++ b/pkg/graph_addons/reverse.go @@ -3,17 +3,7 @@ package graph_addons import "github.com/dominikbraun/graph" func ReverseTopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) { - reverseLess := func(a, b K) bool { - return !less(b, a) - } - topo, err := graph.StableTopologicalSort(g, reverseLess) - if err != nil { - return nil, err - } - for i := 0; i < len(topo)/2; i++ { - topo[i], topo[len(topo)-i-1] = topo[len(topo)-i-1], topo[i] - } - return topo, nil + return toplogicalSort(g, less, true) } func ReverseGraph[K comparable, T any](g graph.Graph[K, T]) (graph.Graph[K, T], error) { diff --git a/pkg/graph_addons/topology.go b/pkg/graph_addons/topology.go new file mode 100644 index 000000000..97f0b37bb --- /dev/null +++ b/pkg/graph_addons/topology.go @@ -0,0 +1,137 @@ +package graph_addons + +import ( + "fmt" + "sort" + + "github.com/dominikbraun/graph" +) + +// TopologicalSort provides a stable topological ordering of resource IDs. +// This is a modified implementation of graph.StableTopologicalSort with the primary difference +// being any uses of the internal function `enqueueArbitrary`. +func TopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) { + return toplogicalSort(g, less, false) +} + +func toplogicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool, invertLess bool) ([]K, error) { + if !g.Traits().IsDirected { + return nil, fmt.Errorf("topological sort cannot be computed on undirected graph") + } + + predecessorMap, err := g.PredecessorMap() + if err != nil { + return nil, fmt.Errorf("failed to get predecessor map: %w", err) + } + + if len(predecessorMap) == 0 { + return nil, nil + } + + queue := make([]K, 0) + queued := make(map[K]struct{}) + enqueue := func(vs ...K) { + for _, vertex := range vs { + queue = append(queue, vertex) + queued[vertex] = struct{}{} + } + } + + for vertex, predecessors := range predecessorMap { + if len(predecessors) == 0 { + enqueue(vertex) + } + } + + // enqueueArbitrary enqueues an arbitray but deterministic id from the remaining unvisited ids. + // It should only be used if len(queue) == 0 && len(predecessorMap) > 0 + enqueueArbitrary := func() { + remainingIds := make([]K, 0, len(predecessorMap)) + for vertex := range predecessorMap { + remainingIds = append(remainingIds, vertex) + } + sort.Slice(remainingIds, func(i, j int) bool { + // Pick an arbitrary vertex to start the queue based first on the number of remaining predecessors + iPcount := len(predecessorMap[remainingIds[i]]) + jPcount := len(predecessorMap[remainingIds[j]]) + if iPcount != jPcount { + if invertLess { + return iPcount >= jPcount + } else { + return iPcount < jPcount + } + } + + // Tie-break on the ID contents themselves + if invertLess { + return !less(remainingIds[i], remainingIds[j]) + } + return less(remainingIds[i], remainingIds[j]) + }) + enqueue(remainingIds[0]) + } + + if len(queue) == 0 { + enqueueArbitrary() + } + + order := make([]K, 0, len(predecessorMap)) + visited := make(map[K]struct{}) + + if invertLess { + sort.Slice(queue, func(i, j int) bool { + return !less(queue[i], queue[j]) + }) + } else { + sort.Slice(queue, func(i, j int) bool { + return less(queue[i], queue[j]) + }) + } + + for len(queue) > 0 { + currentVertex := queue[0] + queue = queue[1:] + + if _, ok := visited[currentVertex]; ok { + continue + } + + order = append(order, currentVertex) + visited[currentVertex] = struct{}{} + delete(predecessorMap, currentVertex) + + frontier := make([]K, 0) + + for vertex, predecessors := range predecessorMap { + delete(predecessors, currentVertex) + + if len(predecessors) != 0 { + continue + } + + if _, ok := queued[vertex]; ok { + continue + } + + frontier = append(frontier, vertex) + } + + if invertLess { + sort.Slice(frontier, func(i, j int) bool { + return !less(frontier[i], frontier[j]) + }) + } else { + sort.Slice(frontier, func(i, j int) bool { + return less(frontier[i], frontier[j]) + }) + } + + enqueue(frontier...) + + if len(queue) == 0 && len(predecessorMap) > 0 { + enqueueArbitrary() + } + } + + return order, nil +} diff --git a/pkg/infra/cli.go b/pkg/infra/cli.go index 245b2b238..de4784972 100644 --- a/pkg/infra/cli.go +++ b/pkg/infra/cli.go @@ -8,6 +8,7 @@ import ( construct "github.com/klothoplatform/klotho/pkg/construct" engine "github.com/klothoplatform/klotho/pkg/engine" + "github.com/klothoplatform/klotho/pkg/engine/constraints" "github.com/klothoplatform/klotho/pkg/infra/iac" "github.com/klothoplatform/klotho/pkg/infra/kubernetes" "github.com/klothoplatform/klotho/pkg/io" @@ -107,7 +108,7 @@ func GenerateIac(cmd *cobra.Command, args []string) error { return err } - solCtx := engine.NewSolutionContext(kb, "") + solCtx := engine.NewSolutionContext(kb, "", &constraints.Constraints{}) err = solCtx.LoadGraph(input.Graph) if err != nil { return err diff --git a/pkg/infra/iac/templates/aws/ecs_service/factory.ts b/pkg/infra/iac/templates/aws/ecs_service/factory.ts index bb2c78cb0..eb3439c06 100644 --- a/pkg/infra/iac/templates/aws/ecs_service/factory.ts +++ b/pkg/infra/iac/templates/aws/ecs_service/factory.ts @@ -6,10 +6,8 @@ import { TemplateWrapper, ModelCaseWrapper } from '../../wrappers' interface Args { AssignPublicIp: Promise | OutputInstance | boolean - DeploymentCircuitBreaker: - | Promise - | OutputInstance - | awsInputs.ecs.ServiceDeploymentCircuitBreaker + DeploymentCircuitBreaker: pulumi.Input + EnableExecuteCommand: boolean ForceNewDeployment: boolean Cluster: aws.ecs.Cluster DesiredCount?: number @@ -38,6 +36,9 @@ function create(args: Args): aws.ecs.Service { //TMPL }, //TMPL {{- end }} desiredCount: args.DesiredCount, + //TMPL {{- if .EnableExecuteCommand }} + enableExecuteCommand: args.EnableExecuteCommand, + //TMPL {{- end }} forceNewDeployment: args.ForceNewDeployment, //TMPL {{- if .LoadBalancers }} loadBalancers: args.LoadBalancers, diff --git a/pkg/infra/iac/templates/aws/ecs_task_definition/factory.ts b/pkg/infra/iac/templates/aws/ecs_task_definition/factory.ts index 3f99f715b..3983d434a 100644 --- a/pkg/infra/iac/templates/aws/ecs_task_definition/factory.ts +++ b/pkg/infra/iac/templates/aws/ecs_task_definition/factory.ts @@ -6,12 +6,14 @@ import { TemplateWrapper, ModelCaseWrapper } from '../../wrappers' interface Args { Name: string + Cpu: string + Memory: string NetworkMode?: string ExecutionRole: aws.iam.Role TaskRole: aws.iam.Role RequiresCompatibilities?: string[] - EfsVolumes: TemplateWrapper - ContainerDefinitions: TemplateWrapper + EfsVolumes: TemplateWrapper + ContainerDefinitions: TemplateWrapper Tags: ModelCaseWrapper> } diff --git a/pkg/infra/iac/templates/aws/target_group/factory.ts b/pkg/infra/iac/templates/aws/target_group/factory.ts index 1dfa13e6b..6338db0eb 100644 --- a/pkg/infra/iac/templates/aws/target_group/factory.ts +++ b/pkg/infra/iac/templates/aws/target_group/factory.ts @@ -1,4 +1,5 @@ import * as aws from '@pulumi/aws' +import * as awsInputs from '@pulumi/aws/types/input' import { TemplateWrapper, ModelCaseWrapper } from '../../wrappers' interface Args { @@ -8,7 +9,7 @@ interface Args { Vpc: aws.ec2.Vpc TargetType: string Targets: { Id: string; Port: number }[] - HealthCheck: TemplateWrapper> + HealthCheck: TemplateWrapper LambdaMultiValueHeadersEnabled?: boolean Tags: ModelCaseWrapper> } diff --git a/pkg/knowledgebase/kb.go b/pkg/knowledgebase/kb.go index 5345937f0..e597f21bc 100644 --- a/pkg/knowledgebase/kb.go +++ b/pkg/knowledgebase/kb.go @@ -67,6 +67,10 @@ func NewKB() *KnowledgeBase { } } +func (kb *KnowledgeBase) Graph() graph.Graph[string, *ResourceTemplate] { + return kb.underlying +} + func (kb *KnowledgeBase) GetModel(model string) *Model { return kb.Models[model] } diff --git a/pkg/knowledgebase/resource_template.go b/pkg/knowledgebase/resource_template.go index 195cff280..268f8adbd 100644 --- a/pkg/knowledgebase/resource_template.go +++ b/pkg/knowledgebase/resource_template.go @@ -435,3 +435,13 @@ func (tmpl ResourceTemplate) LoopProperties(res *construct.Resource, addProp fun } return errs } + +func IsCollectionProperty(p Property) bool { + if _, ok := p.(CollectionProperty); ok { + return true + } + if _, ok := p.(MapProperty); ok { + return true + } + return false +} diff --git a/pkg/templates/aws/edges/ecs_task_definition-iam_role.yaml b/pkg/templates/aws/edges/ecs_task_definition-iam_role.yaml index a7488c2dc..7b28a09c7 100644 --- a/pkg/templates/aws/edges/ecs_task_definition-iam_role.yaml +++ b/pkg/templates/aws/edges/ecs_task_definition-iam_role.yaml @@ -21,3 +21,23 @@ operational_rules: field: ManagedPolicies value: - arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy + - if: | + {{ (fieldValue "EnableExecuteCommand" (upstream "aws:ecs_service" .Source)) }} + configuration_rules: + - resource: | + {{ .Target }} + configuration: + field: InlinePolicies + value: + - Name: '{{ .Source.Name }}-ecs-exec' + Policy: + Version: '2012-10-17' + Statement: + - Action: + - ssmmessages:CreateControlChannel + - ssmmessages:CreateDataChannel + - ssmmessages:OpenControlChannel + - ssmmessages:OpenDataChannel + Effect: Allow + Resource: + - '*' diff --git a/pkg/templates/aws/edges/target_group-ecs_service.yaml b/pkg/templates/aws/edges/target_group-ecs_service.yaml index f447b2102..b249fa119 100644 --- a/pkg/templates/aws/edges/target_group-ecs_service.yaml +++ b/pkg/templates/aws/edges/target_group-ecs_service.yaml @@ -15,4 +15,4 @@ operational_rules: field: TargetType value: ip classification: - - service_endpoint \ No newline at end of file + - service_endpoint diff --git a/pkg/templates/aws/resources/ecs_service.yaml b/pkg/templates/aws/resources/ecs_service.yaml index 4a2cd6d0b..03abf06eb 100644 --- a/pkg/templates/aws/resources/ecs_service.yaml +++ b/pkg/templates/aws/resources/ecs_service.yaml @@ -27,6 +27,10 @@ properties: default_value: 1 description: The number of instantiations of the specified task definition to keep running on the service + EnableExecuteCommand: + type: bool + default_value: false + description: Whether to enable Amazon ECS Exec for the service. See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs-exec.html ForceNewDeployment: type: bool default_value: true @@ -125,6 +129,6 @@ views: dataflow: big deployment_permissions: - deploy: ["ecs:CreateService"] - tear_down: ["ecs:DeleteService"] - update: ["ecs:UpdateService"] \ No newline at end of file + deploy: ['ecs:CreateService'] + tear_down: ['ecs:DeleteService'] + update: ['ecs:UpdateService'] From 056ffd3dd9e5636574b951d7b64f9d11e422be8a Mon Sep 17 00:00:00 2001 From: Gordon Date: Sun, 3 Mar 2024 17:50:34 -0500 Subject: [PATCH 2/6] Update tests for new EnableExecuteCommand property --- pkg/engine/testdata/ecs_rds.expect.yaml | 1 + pkg/engine/testdata/idempotent_constraints.expect.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/engine/testdata/ecs_rds.expect.yaml b/pkg/engine/testdata/ecs_rds.expect.yaml index b382c2135..1c6b7bd49 100755 --- a/pkg/engine/testdata/ecs_rds.expect.yaml +++ b/pkg/engine/testdata/ecs_rds.expect.yaml @@ -21,6 +21,7 @@ resources: AssignPublicIp: false Cluster: aws:ecs_cluster:ecs_cluster-0 DesiredCount: 1 + EnableExecuteCommand: false ForceNewDeployment: true LaunchType: FARGATE SecurityGroups: diff --git a/pkg/engine/testdata/idempotent_constraints.expect.yaml b/pkg/engine/testdata/idempotent_constraints.expect.yaml index 697158e86..236f5a277 100755 --- a/pkg/engine/testdata/idempotent_constraints.expect.yaml +++ b/pkg/engine/testdata/idempotent_constraints.expect.yaml @@ -21,6 +21,7 @@ resources: AssignPublicIp: false Cluster: aws:ecs_cluster:ecs_cluster-0 DesiredCount: 1 + EnableExecuteCommand: false ForceNewDeployment: true LaunchType: FARGATE SecurityGroups: From ce173091a7467c765d73c3bc6c08b56439d40c96 Mon Sep 17 00:00:00 2001 From: Gordon Date: Sun, 3 Mar 2024 19:55:34 -0500 Subject: [PATCH 3/6] Add property names to graph edges when possible --- cmd/kb/dot.go | 78 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 12 deletions(-) diff --git a/cmd/kb/dot.go b/cmd/kb/dot.go index 064815f88..615fc912c 100644 --- a/cmd/kb/dot.go +++ b/cmd/kb/dot.go @@ -13,6 +13,7 @@ import ( "github.com/klothoplatform/klotho/pkg/dot" "github.com/klothoplatform/klotho/pkg/graph_addons" "github.com/klothoplatform/klotho/pkg/knowledgebase" + "github.com/klothoplatform/klotho/pkg/knowledgebase/properties" ) func dotAttributes(tmpl *knowledgebase.ResourceTemplate, props graph.VertexProperties) map[string]string { @@ -27,7 +28,11 @@ func dotAttributes(tmpl *knowledgebase.ResourceTemplate, props graph.VertexPrope return a } -func dotEdgeAttributes(e *knowledgebase.EdgeTemplate, props graph.EdgeProperties) map[string]string { +func dotEdgeAttributes( + kb knowledgebase.TemplateKB, + e *knowledgebase.EdgeTemplate, + props graph.EdgeProperties, +) map[string]string { a := make(map[string]string) for k, v := range props.Attributes { a[k] = v @@ -36,10 +41,67 @@ func dotEdgeAttributes(e *knowledgebase.EdgeTemplate, props graph.EdgeProperties a["style"] = "dashed" } a["edgetooltip"] = fmt.Sprintf("%s -> %s", e.Source, e.Target) + + if source, err := kb.GetResourceTemplate(e.Source); err == nil { + var isTarget func(ps knowledgebase.Properties) knowledgebase.Property + isTarget = func(ps knowledgebase.Properties) knowledgebase.Property { + for _, p := range ps { + name := p.Details().Name + if name == "" { + fmt.Print() + } + switch inst := p.(type) { + case *properties.ResourceProperty: + if inst.AllowedTypes.MatchesAny(e.Target) { + return p + } + + case knowledgebase.CollectionProperty: + if ip := inst.Item(); ip != nil { + ret := isTarget(knowledgebase.Properties{"item": ip}) + if ret != nil { + return ret + } + } + + case knowledgebase.MapProperty: + mapProps := make(knowledgebase.Properties) + if kp := inst.Key(); kp != nil { + mapProps["key"] = kp + } + if vp := inst.Value(); vp != nil { + mapProps["value"] = vp + } + ret := isTarget(mapProps) + if ret != nil { + return ret + } + } + return isTarget(p.SubProperties()) + } + return nil + } + prop := isTarget(source.Properties) + if prop != nil { + if label, ok := a["label"]; ok { + a["label"] = label + "\n" + prop.Details().Path + } else { + a["label"] = prop.Details().Path + } + } + } return a } -func KbToDot(g graph.Graph[string, *knowledgebase.ResourceTemplate], out io.Writer) error { +func KbToDot(kb knowledgebase.TemplateKB, out io.Writer) error { + hasGraph, ok := kb.(interface { + Graph() graph.Graph[string, *knowledgebase.ResourceTemplate] + }) + if !ok { + return fmt.Errorf("knowledgebase does not have a graph") + } + g := hasGraph.Graph() + ids, err := graph_addons.TopologicalSort(g, func(a, b string) bool { return a < b }) @@ -93,7 +155,7 @@ func KbToDot(g graph.Graph[string, *knowledgebase.ResourceTemplate], out io.Writ errs = errors.Join(errs, fmt.Errorf("edge %q -> %q has no EdgeTemplate", e.Source, e.Target)) continue } - printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(et, e.Properties))) + printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(kb, et, e.Properties))) } printf("}\n") return errs @@ -109,16 +171,8 @@ func KbToSVG(kb knowledgebase.TemplateKB, prefix string) error { } defer f.Close() - hasGraph, ok := kb.(interface { - Graph() graph.Graph[string, *knowledgebase.ResourceTemplate] - }) - if !ok { - return fmt.Errorf("knowledgebase does not have a graph") - } - g := hasGraph.Graph() - dotContent := new(bytes.Buffer) - err = KbToDot(g, io.MultiWriter(f, dotContent)) + err = KbToDot(kb, io.MultiWriter(f, dotContent)) if err != nil { return fmt.Errorf("could not render graph to file %s: %v", prefix+".gv", err) } From e09cbc45ccaba90907eed46be2a7a1e959fd231c Mon Sep 17 00:00:00 2001 From: Gordon Date: Mon, 4 Mar 2024 09:46:49 -0500 Subject: [PATCH 4/6] Remove old check which is redundant --- pkg/engine/path_selection/path_selection.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pkg/engine/path_selection/path_selection.go b/pkg/engine/path_selection/path_selection.go index 7a3a64c78..9be9a97b5 100644 --- a/pkg/engine/path_selection/path_selection.go +++ b/pkg/engine/path_selection/path_selection.go @@ -119,12 +119,10 @@ func BuildPathSelectionGraph( return nil, err } if !prevRes.IsZero() { - edgeTemplate := kb.GetEdgeTemplate(prevRes, id) - if edgeTemplate != nil && !edgeTemplate.DirectEdgeOnly { - err := tempGraph.AddEdge(prevRes, id, graph.EdgeWeight(calculateEdgeWeight(dep, prevRes, id, 0, 0, classification, kb))) - if err != nil { - return nil, err - } + err := tempGraph.AddEdge(prevRes, id, + graph.EdgeWeight(calculateEdgeWeight(dep, prevRes, id, 0, 0, classification, kb))) + if err != nil { + return nil, err } } prevRes = id From 59377b66ab8485b6311a62e572810dda316a5be4 Mon Sep 17 00:00:00 2001 From: Gordon Date: Mon, 4 Mar 2024 10:33:13 -0500 Subject: [PATCH 5/6] Add unit test for TopologicalSort --- pkg/construct/graph_vertices.go | 123 +----------------------------- pkg/graph_addons/reverse.go | 18 ++++- pkg/graph_addons/topology.go | 105 ++++++++++++------------- pkg/graph_addons/topology_test.go | 94 +++++++++++++++++++++++ 4 files changed, 164 insertions(+), 176 deletions(-) create mode 100644 pkg/graph_addons/topology_test.go diff --git a/pkg/construct/graph_vertices.go b/pkg/construct/graph_vertices.go index cda824c71..cf42e09a4 100644 --- a/pkg/construct/graph_vertices.go +++ b/pkg/construct/graph_vertices.go @@ -2,139 +2,22 @@ package construct import ( "errors" - "fmt" - "slices" - "sort" "github.com/dominikbraun/graph" + "github.com/klothoplatform/klotho/pkg/graph_addons" ) // TopologicalSort provides a stable topological ordering of resource IDs. // This is a modified implementation of graph.StableTopologicalSort with the primary difference // being any uses of the internal function `enqueueArbitrary`. func TopologicalSort[T any](g graph.Graph[ResourceId, T]) ([]ResourceId, error) { - return toplogicalSort(g, false) + return graph_addons.TopologicalSort(g, ResourceIdLess) } // ReverseTopologicalSort is like TopologicalSort, but returns the reverse order. This is primarily useful for // IaC graphs to determine the order in which resources should be created. func ReverseTopologicalSort[T any](g graph.Graph[ResourceId, T]) ([]ResourceId, error) { - topo, err := toplogicalSort(g, true) - if err != nil { - return nil, err - } - slices.Reverse(topo) - return topo, nil -} - -func toplogicalSort[T any](g graph.Graph[ResourceId, T], invertLess bool) ([]ResourceId, error) { - if !g.Traits().IsDirected { - return nil, fmt.Errorf("topological sort cannot be computed on undirected graph") - } - - predecessorMap, err := g.PredecessorMap() - if err != nil { - return nil, fmt.Errorf("failed to get predecessor map: %w", err) - } - - if len(predecessorMap) == 0 { - return nil, nil - } - - queue := make([]ResourceId, 0) - queued := make(map[ResourceId]struct{}) - enqueue := func(vs ...ResourceId) { - for _, vertex := range vs { - queue = append(queue, vertex) - queued[vertex] = struct{}{} - } - } - - for vertex, predecessors := range predecessorMap { - if len(predecessors) == 0 { - enqueue(vertex) - } - } - - // enqueueArbitrary enqueues an arbitray but deterministic id from the remaining unvisited ids. - // It should only be used if len(queue) == 0 && len(predecessorMap) > 0 - enqueueArbitrary := func() { - remainingIds := make([]ResourceId, 0, len(predecessorMap)) - for vertex := range predecessorMap { - remainingIds = append(remainingIds, vertex) - } - sort.Slice(remainingIds, func(i, j int) bool { - // Pick an arbitrary vertex to start the queue based first on the number of remaining predecessors - iPcount := len(predecessorMap[remainingIds[i]]) - jPcount := len(predecessorMap[remainingIds[j]]) - if iPcount != jPcount { - if invertLess { - return iPcount >= jPcount - } else { - return iPcount < jPcount - } - } - - // Tie-break on the ID contents themselves - if invertLess { - return !SortedIds(remainingIds).Less(i, j) - } - return SortedIds(remainingIds).Less(i, j) - }) - enqueue(remainingIds[0]) - } - - if len(queue) == 0 { - enqueueArbitrary() - } - - order := make([]ResourceId, 0, len(predecessorMap)) - visited := make(map[ResourceId]struct{}) - - sort.Sort(SortedIds(queue)) - - for len(queue) > 0 { - currentVertex := queue[0] - queue = queue[1:] - - if _, ok := visited[currentVertex]; ok { - continue - } - - order = append(order, currentVertex) - visited[currentVertex] = struct{}{} - delete(predecessorMap, currentVertex) - - frontier := make([]ResourceId, 0) - - for vertex, predecessors := range predecessorMap { - delete(predecessors, currentVertex) - - if len(predecessors) != 0 { - continue - } - - if _, ok := queued[vertex]; ok { - continue - } - - frontier = append(frontier, vertex) - } - - if invertLess { - sort.Sort(sort.Reverse(SortedIds(frontier))) - } else { - sort.Sort(SortedIds(frontier)) - } - - enqueue(frontier...) - - if len(queue) == 0 && len(predecessorMap) > 0 { - enqueueArbitrary() - } - } - - return order, nil + return graph_addons.ReverseTopologicalSort(g, ResourceIdLess) } // WalkGraphFunc is much like `fs.WalkDirFunc` and is used in `WalkGraph` and `WalkGraphReverse` for the callback diff --git a/pkg/graph_addons/reverse.go b/pkg/graph_addons/reverse.go index 8bb369e33..84be9ade3 100644 --- a/pkg/graph_addons/reverse.go +++ b/pkg/graph_addons/reverse.go @@ -2,8 +2,24 @@ package graph_addons import "github.com/dominikbraun/graph" +// ReverseLess is a helper function that returns a new less function that reverses the order of the original less function. +func ReverseLess[K any](less func(K, K) bool) func(K, K) bool { + return func(a, b K) bool { + return less(b, a) + } +} + +// TopologicalSort provides a stable topological ordering. Note, the following is true: +// +// ReverseTopologicalSort(g, ReverseLess(less)) == reverse(TopologicalSort(g, less)) +// +// Meaning, the reverse topological sort of a graph under a reversed less is the reverse of the topological sort. func ReverseTopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) { - return toplogicalSort(g, less, true) + adjacencyMap, err := g.AdjacencyMap() + if err != nil { + return nil, err + } + return topologicalSort(adjacencyMap, less) } func ReverseGraph[K comparable, T any](g graph.Graph[K, T]) (graph.Graph[K, T], error) { diff --git a/pkg/graph_addons/topology.go b/pkg/graph_addons/topology.go index 97f0b37bb..8ee172c55 100644 --- a/pkg/graph_addons/topology.go +++ b/pkg/graph_addons/topology.go @@ -1,33 +1,43 @@ package graph_addons import ( - "fmt" "sort" "github.com/dominikbraun/graph" ) -// TopologicalSort provides a stable topological ordering of resource IDs. -// This is a modified implementation of graph.StableTopologicalSort with the primary difference -// being any uses of the internal function `enqueueArbitrary`. +// TopologicalSort provides a stable topological ordering. func TopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) { - return toplogicalSort(g, less, false) -} - -func toplogicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool, invertLess bool) ([]K, error) { - if !g.Traits().IsDirected { - return nil, fmt.Errorf("topological sort cannot be computed on undirected graph") - } - - predecessorMap, err := g.PredecessorMap() + predecessors, err := g.PredecessorMap() if err != nil { - return nil, fmt.Errorf("failed to get predecessor map: %w", err) + return nil, err } + return topologicalSort(predecessors, less) +} - if len(predecessorMap) == 0 { +// topologicalSort performs a topological sort on a graph with the given dependencies. +// Whether the sort is regular or reverse is determined by whether the `deps` map is a PredecessorMap or AdjacencyMap. +// The `less` function is used to determine the order of vertices in the result. +// This is a modified implementation of graph.StableTopologicalSort with the primary difference +// being any uses of the internal function `enqueueArbitrary`. +func topologicalSort[K comparable](deps map[K]map[K]graph.Edge[K], less func(K, K) bool) ([]K, error) { + if len(deps) == 0 { return nil, nil } + reverseSort := false + // PredecessorMap (for regular topological sort) returns a map source -> targets, so check the first edge to see + // if the source matches the top map's key. AdjacencyMap (for reverse topological sort) returns a map + // target -> sources, so if the edge's source doesn't match the top map's key, we need to invert the less function. + for source, ts := range deps { + for _, edge := range ts { + if edge.Source != source { + reverseSort = true + } + break + } + } + queue := make([]K, 0) queued := make(map[K]struct{}) enqueue := func(vs ...K) { @@ -37,56 +47,47 @@ func toplogicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bo } } - for vertex, predecessors := range predecessorMap { - if len(predecessors) == 0 { + for vertex, vdeps := range deps { + if len(vdeps) == 0 { enqueue(vertex) } } // enqueueArbitrary enqueues an arbitray but deterministic id from the remaining unvisited ids. - // It should only be used if len(queue) == 0 && len(predecessorMap) > 0 + // It should only be used if len(queue) == 0 && len(deps) > 0 enqueueArbitrary := func() { - remainingIds := make([]K, 0, len(predecessorMap)) - for vertex := range predecessorMap { - remainingIds = append(remainingIds, vertex) + remaining := make([]K, 0, len(deps)) + for vertex := range deps { + remaining = append(remaining, vertex) } - sort.Slice(remainingIds, func(i, j int) bool { - // Pick an arbitrary vertex to start the queue based first on the number of remaining predecessors - iPcount := len(predecessorMap[remainingIds[i]]) - jPcount := len(predecessorMap[remainingIds[j]]) + sort.Slice(remaining, func(i, j int) bool { + // Pick an arbitrary vertex to start the queue based first on the number of remaining deps + iPcount := len(deps[remaining[i]]) + jPcount := len(deps[remaining[j]]) if iPcount != jPcount { - if invertLess { - return iPcount >= jPcount + if reverseSort { + return jPcount < iPcount } else { return iPcount < jPcount } } - // Tie-break on the ID contents themselves - if invertLess { - return !less(remainingIds[i], remainingIds[j]) - } - return less(remainingIds[i], remainingIds[j]) + // Tie-break using the less function on contents themselves + return less(remaining[i], remaining[j]) }) - enqueue(remainingIds[0]) + enqueue(remaining[0]) } if len(queue) == 0 { enqueueArbitrary() } - order := make([]K, 0, len(predecessorMap)) + order := make([]K, 0, len(deps)) visited := make(map[K]struct{}) - if invertLess { - sort.Slice(queue, func(i, j int) bool { - return !less(queue[i], queue[j]) - }) - } else { - sort.Slice(queue, func(i, j int) bool { - return less(queue[i], queue[j]) - }) - } + sort.Slice(queue, func(i, j int) bool { + return less(queue[i], queue[j]) + }) for len(queue) > 0 { currentVertex := queue[0] @@ -98,11 +99,11 @@ func toplogicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bo order = append(order, currentVertex) visited[currentVertex] = struct{}{} - delete(predecessorMap, currentVertex) + delete(deps, currentVertex) frontier := make([]K, 0) - for vertex, predecessors := range predecessorMap { + for vertex, predecessors := range deps { delete(predecessors, currentVertex) if len(predecessors) != 0 { @@ -116,19 +117,13 @@ func toplogicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bo frontier = append(frontier, vertex) } - if invertLess { - sort.Slice(frontier, func(i, j int) bool { - return !less(frontier[i], frontier[j]) - }) - } else { - sort.Slice(frontier, func(i, j int) bool { - return less(frontier[i], frontier[j]) - }) - } + sort.Slice(frontier, func(i, j int) bool { + return less(frontier[i], frontier[j]) + }) enqueue(frontier...) - if len(queue) == 0 && len(predecessorMap) > 0 { + if len(queue) == 0 && len(deps) > 0 { enqueueArbitrary() } } diff --git a/pkg/graph_addons/topology_test.go b/pkg/graph_addons/topology_test.go new file mode 100644 index 000000000..366585539 --- /dev/null +++ b/pkg/graph_addons/topology_test.go @@ -0,0 +1,94 @@ +package graph_addons + +import ( + "slices" + "testing" + + "github.com/dominikbraun/graph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_topologicalSort(t *testing.T) { + type Edge = graph.Edge[int] + less := func(a, b int) bool { + return a < b + } + + tests := map[string]struct { + vertices []int + edges []Edge + expectedOrder []int + shouldFail bool + }{ + "graph with 5 vertices": { + vertices: []int{1, 2, 3, 4, 5}, + edges: []Edge{ + {Source: 1, Target: 2}, + {Source: 1, Target: 3}, + {Source: 2, Target: 3}, + {Source: 2, Target: 4}, + {Source: 2, Target: 5}, + {Source: 3, Target: 4}, + {Source: 4, Target: 5}, + }, + expectedOrder: []int{1, 2, 3, 4, 5}, + }, + "graph with many possible topological orders": { + vertices: []int{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60}, + edges: []Edge{ + {Source: 1, Target: 10}, + {Source: 2, Target: 20}, + {Source: 3, Target: 30}, + {Source: 4, Target: 40}, + {Source: 5, Target: 50}, + {Source: 6, Target: 60}, + }, + expectedOrder: []int{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60}, + }, + "graph with cycle": { + vertices: []int{1, 2, 3, 4}, + edges: []Edge{ // 1 -> 3 -> 2 -> 4 ↺ + {Source: 1, Target: 3}, + {Source: 3, Target: 2}, + {Source: 2, Target: 4}, + {Source: 4, Target: 1}, + }, + expectedOrder: []int{1, 3, 2, 4}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + require, assert := require.New(t), assert.New(t) + + g := graph.New(graph.IntHash, graph.Directed()) + + for _, vertex := range test.vertices { + _ = g.AddVertex(vertex) + } + + for _, edge := range test.edges { + require.NoError( + g.AddEdge(edge.Source, edge.Target, func(ep *graph.EdgeProperties) { *ep = edge.Properties }), + ) + } + + order, err := TopologicalSort(g, less) + + if test.shouldFail { + require.Error(err) + return + } + require.NoError(err) + + assert.Equal(test.expectedOrder, order, "regular order doesn't match") + + reverse, err := ReverseTopologicalSort(g, ReverseLess(less)) + require.NoError(err) + + slices.Reverse(test.expectedOrder) + assert.Equal(test.expectedOrder, reverse, "reverse order doesn't match") + }) + } +} From 58f2aa3123b033b8b0111ad799a0f8a406869e7d Mon Sep 17 00:00:00 2001 From: Gordon Date: Mon, 4 Mar 2024 10:50:30 -0500 Subject: [PATCH 6/6] Improve topologicalSort cycle breaking --- pkg/graph_addons/reverse.go | 6 +----- pkg/graph_addons/topology.go | 28 ++++++---------------------- pkg/graph_addons/topology_test.go | 28 +++++++++++++++++++--------- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/pkg/graph_addons/reverse.go b/pkg/graph_addons/reverse.go index 84be9ade3..2d6a7e0dc 100644 --- a/pkg/graph_addons/reverse.go +++ b/pkg/graph_addons/reverse.go @@ -9,11 +9,7 @@ func ReverseLess[K any](less func(K, K) bool) func(K, K) bool { } } -// TopologicalSort provides a stable topological ordering. Note, the following is true: -// -// ReverseTopologicalSort(g, ReverseLess(less)) == reverse(TopologicalSort(g, less)) -// -// Meaning, the reverse topological sort of a graph under a reversed less is the reverse of the topological sort. +// TopologicalSort provides a stable topological ordering. func ReverseTopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) { adjacencyMap, err := g.AdjacencyMap() if err != nil { diff --git a/pkg/graph_addons/topology.go b/pkg/graph_addons/topology.go index 8ee172c55..17a805343 100644 --- a/pkg/graph_addons/topology.go +++ b/pkg/graph_addons/topology.go @@ -25,19 +25,6 @@ func topologicalSort[K comparable](deps map[K]map[K]graph.Edge[K], less func(K, return nil, nil } - reverseSort := false - // PredecessorMap (for regular topological sort) returns a map source -> targets, so check the first edge to see - // if the source matches the top map's key. AdjacencyMap (for reverse topological sort) returns a map - // target -> sources, so if the edge's source doesn't match the top map's key, we need to invert the less function. - for source, ts := range deps { - for _, edge := range ts { - if edge.Source != source { - reverseSort = true - } - break - } - } - queue := make([]K, 0) queued := make(map[K]struct{}) enqueue := func(vs ...K) { @@ -61,15 +48,12 @@ func topologicalSort[K comparable](deps map[K]map[K]graph.Edge[K], less func(K, remaining = append(remaining, vertex) } sort.Slice(remaining, func(i, j int) bool { - // Pick an arbitrary vertex to start the queue based first on the number of remaining deps - iPcount := len(deps[remaining[i]]) - jPcount := len(deps[remaining[j]]) - if iPcount != jPcount { - if reverseSort { - return jPcount < iPcount - } else { - return iPcount < jPcount - } + // Start based first on the number of remaining deps, prioritizing vertices with fewer deps + // to make it most likely to break any cycles, reducing the amount of arbitrary choices. + ic := len(deps[remaining[i]]) + jc := len(deps[remaining[j]]) + if ic != jc { + return ic < jc } // Tie-break using the less function on contents themselves diff --git a/pkg/graph_addons/topology_test.go b/pkg/graph_addons/topology_test.go index 366585539..dab132121 100644 --- a/pkg/graph_addons/topology_test.go +++ b/pkg/graph_addons/topology_test.go @@ -19,6 +19,7 @@ func Test_topologicalSort(t *testing.T) { vertices []int edges []Edge expectedOrder []int + reverseOrder []int // defaults to reverse(expectedOrder) shouldFail bool }{ "graph with 5 vertices": { @@ -47,14 +48,19 @@ func Test_topologicalSort(t *testing.T) { expectedOrder: []int{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60}, }, "graph with cycle": { - vertices: []int{1, 2, 3, 4}, - edges: []Edge{ // 1 -> 3 -> 2 -> 4 ↺ - {Source: 1, Target: 3}, - {Source: 3, Target: 2}, - {Source: 2, Target: 4}, - {Source: 4, Target: 1}, + vertices: []int{1, 2, 3, 4, 5}, + edges: []Edge{ + {Source: 5, Target: 1}, + + // 1 -> 2 -> 3 ↺ 1 + {Source: 1, Target: 2}, + {Source: 2, Target: 3}, + {Source: 3, Target: 1}, + + {Source: 3, Target: 4}, }, - expectedOrder: []int{1, 3, 2, 4}, + expectedOrder: []int{5, 1, 2, 3, 4}, + reverseOrder: []int{4, 5, 3, 2, 1}, }, } @@ -87,8 +93,12 @@ func Test_topologicalSort(t *testing.T) { reverse, err := ReverseTopologicalSort(g, ReverseLess(less)) require.NoError(err) - slices.Reverse(test.expectedOrder) - assert.Equal(test.expectedOrder, reverse, "reverse order doesn't match") + if test.reverseOrder == nil { + test.reverseOrder = make([]int, len(test.expectedOrder)) + copy(test.reverseOrder, test.expectedOrder) + slices.Reverse(test.reverseOrder) + } + assert.Equal(test.reverseOrder, reverse, "reverse order doesn't match") }) } }