forked from canonical/sqlair-prototype
-
Notifications
You must be signed in to change notification settings - Fork 0
/
statement.go
130 lines (109 loc) · 3.45 KB
/
statement.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package sqlair
import (
"github.com/canonical/sqlair/internal/parse"
sqlairreflect "github.com/canonical/sqlair/internal/reflect"
)
// typeMap is a convenience type alias for reflection
// information indexed by type name.
// Because it is declared as an alias (=) rather than a new defined type,
// it is fully interchangeable with map[string]sqlairreflect.Info.
type typeMap = map[string]sqlairreflect.Info
// Statement represents a prepared Sqlair DSL statement
// that can be executed by the database.
// Prepare builds Statement values. Both fields are unexported, so code
// outside package sqlair cannot populate them directly.
type Statement struct {
	// expression is the parsed expression tree for this statement.
	expression parse.Expression
	// argTypes holds the reflection info for types used in this statement,
	// keyed by reflected type name.
	argTypes typeMap
}
// Prepare parses the Sqlair DSL in stmt and binds it to the type information
// of the optional args, producing a Statement that can be passed to the
// database for execution.
//
// Preparation proceeds in three steps:
//   - stmt is lexed and parsed into an expression tree.
//   - Reflection information for each of args is retrieved or generated.
//   - That reflection information is checked against the input and output
//     targets found in the expression tree.
//
// The first error from any step is returned unchanged, with a nil Statement.
func Prepare(stmt string, args ...any) (*Statement, error) {
	exp, err := parse.NewParser(parse.NewLexer(stmt)).Run()
	if err != nil {
		return nil, err
	}

	types, err := typesForStatement(args)
	if err != nil {
		return nil, err
	}

	if err = validateExpressionTypes(exp, types); err != nil {
		return nil, err
	}

	prepared := &Statement{expression: exp, argTypes: types}
	return prepared, nil
}
// typesForStatement returns reflection information for the input arguments,
// keyed by each argument's reflected type name.
//
// The reflected type name of each argument must be unique within args. A
// repeated name yields the error from NewErrTypeNameNotUnique. To use the same
// underlying type more than once in a statement, declare a distinct local
// type for it:
//
//	type Person struct{}
//	type Manager Person
//
//	stmt, err := sqlair.Prepare(`
//	    SELECT p.* AS &Person.*,
//	           m.* AS &Manager.*
//	    FROM person AS p
//	    JOIN person AS m
//	    ON p.manager_id = m.id
//	    WHERE p.name = 'Fred'`, Person{}, Manager{})
//
// Any error from reflecting an argument is returned unchanged.
func typesForStatement(args []any) (typeMap, error) {
	cache := sqlairreflect.Cache()
	types := make(typeMap, len(args))

	for i := range args {
		info, err := cache.Reflect(args[i])
		if err != nil {
			return nil, err
		}

		n := info.Name()
		if _, dup := types[n]; dup {
			return nil, NewErrTypeNameNotUnique(n)
		}
		types[n] = info
	}

	return types, nil
}
// validateExpressionTypes walks the input expression tree to ensure that:
//   - Each input/output target in the expression has type information in
//     argTypes.
//   - All type information in argTypes is actually required by some
//     input/output target.
//
// If a target's type is missing from argTypes, the walk stops and the error
// from NewErrTypeInfoNotPresent for that type name is returned. Otherwise, if
// any supplied types went unused, the error from NewErrSuperfluousType is
// returned for the lexicographically smallest unused name. That choice keeps
// the result independent of map iteration order. It returns nil when every
// target is covered and every supplied type is used.
func validateExpressionTypes(statementExp parse.Expression, argTypes typeMap) error {
	var err error
	seen := make(map[string]bool)

	visit := func(exp parse.Expression) bool {
		// Only input sources and output targets reference argument types.
		if t := exp.Type(); t != parse.OutputTarget && t != parse.InputSource {
			return true
		}

		// Select the type identity, such as "Person" in the case of
		// "$Person.id", and ensure that there is type information for it.
		// NOTE(review): the name is read from child index 1; presumably
		// child 0 is a leading marker/prefix node. Confirm against the parser.
		typeName := exp.Expressions()[1].String()
		if _, ok := argTypes[typeName]; !ok {
			err = NewErrTypeInfoNotPresent(typeName)
			return false
		}
		seen[typeName] = true
		return true
	}

	// If we did not complete the walk through the tree,
	// return the error that we encountered.
	if !parse.Walk(statementExp, visit) {
		return err
	}

	// Now compare the type names that we saw against what we have information
	// for. If unused types were supplied, it is an error condition.
	// Go randomizes map iteration order, so reporting whichever unused name
	// the range loop yields first would make the error vary between runs.
	// Pick the smallest unused name instead.
	unused, found := "", false
	for name := range argTypes {
		if seen[name] {
			continue
		}
		if !found || name < unused {
			unused, found = name, true
		}
	}
	if found {
		return NewErrSuperfluousType(unused)
	}
	return nil
}