Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(postgresql): set query result column nullable if selected column contains json operator or null value in case expression #3739

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
184 changes: 135 additions & 49 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,33 +132,18 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
if res.Name != nil {
name = *res.Name
}
switch n.Val.(type) {
case *ast.String:
cols = append(cols, &Column{Name: name, DataType: "text", NotNull: true})
case *ast.Integer:
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
case *ast.Float:
cols = append(cols, &Column{Name: name, DataType: "float", NotNull: true})
case *ast.Boolean:
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
default:
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
col := convertAConstToColumn(n, name)
if col.DataType == "null" {
col.DataType = "any"
}
cols = append(cols, col)

case *ast.A_Expr:
name := ""
if res.Name != nil {
name = *res.Name
}
switch op := astutils.Join(n.Name, ""); {
case lang.IsComparisonOperator(op):
// TODO: Generate a name for these operations
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
case lang.IsMathematicalOperator(op):
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
default:
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
}
cols = append(cols, convertAExprToColumn(n, name))

case *ast.BoolExpr:
name := ""
Expand All @@ -183,44 +168,70 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: notNull})

case *ast.CaseExpr:
name := ""
var name string
if res.Name != nil {
name = *res.Name
}
// TODO: The TypeCase and A_Const code has been copied from below. Instead, we
// need a recurse function to get the type of a node.
if tc, ok := n.Defresult.(*ast.TypeCast); ok {
if tc.TypeName == nil {
return nil, errors.New("no type name type cast")

chosenType := ""
chosenNullable := false

for _, i := range n.Args.Items {
cw := i.(*ast.CaseWhen)
col, err := convertCaseExprCondToColumn(cw.Result, &name)
if err != nil {
return nil, err
}
if col.DataType == "null" {
// we don't choose type from this column if its value is null, only choose nullability
chosenNullable = true
continue
}
name := ""
if ref, ok := tc.Arg.(*ast.ColumnRef); ok {
name = astutils.Join(ref.Fields, "_")
if col.DataType != chosenType {
if chosenType == "" {
chosenType = col.DataType
} else {
chosenType = "any"
}
}
if res.Name != nil {
name = *res.Name
if !col.NotNull {
chosenNullable = true
}
// TODO Validate column names
col := toColumn(tc.TypeName)
col.Name = name
cols = append(cols, col)
} else if aconst, ok := n.Defresult.(*ast.A_Const); ok {
switch aconst.Val.(type) {
case *ast.String:
cols = append(cols, &Column{Name: name, DataType: "text", NotNull: true})
case *ast.Integer:
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
case *ast.Float:
cols = append(cols, &Column{Name: name, DataType: "float", NotNull: true})
case *ast.Boolean:
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
default:
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
}

var defaultCol *Column
if n.Defresult.Pos() != 0 {
defaultCol, err = convertCaseExprCondToColumn(n.Defresult, &name)
if err != nil {
return nil, err
}
} else {
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
defaultCol = &Column{Name: name, DataType: "null", NotNull: false}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing ELSE clause on CaseExpr is considered to return null value

}

if defaultCol.DataType == "null" {
// we don't choose type from this column if its value is null, only choose nullability
chosenNullable = true
} else {
if defaultCol.DataType != chosenType {
if chosenType == "" {
chosenType = defaultCol.DataType
} else {
chosenType = "any"
}
}
if !defaultCol.NotNull {
chosenNullable = true
}
}

if chosenType == "" {
chosenType = "any"
}

chosenColumn := &Column{Name: name, DataType: chosenType, NotNull: !chosenNullable}
cols = append(cols, chosenColumn)

case *ast.CoalesceExpr:
name := "coalesce"
if res.Name != nil {
Expand Down Expand Up @@ -256,7 +267,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er

case *ast.ColumnRef:
if hasStarRef(n) {

// add a column with a reference to an embedded table
if embed, ok := qc.embeds.Find(n); ok {
cols = append(cols, &Column{
Expand Down Expand Up @@ -366,6 +376,11 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
col.NotNull = false
}
}
if expr, ok := n.Arg.(*ast.A_Expr); ok {
if op := astutils.Join(expr.Name, ""); lang.IsJSONOperator(op) {
col.NotNull = false
}
}
cols = append(cols, col)

case *ast.SelectStmt:
Expand Down Expand Up @@ -764,3 +779,74 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)

return nil
}

func convertCaseExprCondToColumn(n ast.Node, resTargetName *string) (*Column, error) {
var col *Column
name := ""
if resTargetName != nil {
name = *resTargetName
}

if tc, ok := n.(*ast.TypeCast); ok {
if tc.TypeName == nil {
return nil, errors.New("no type name type cast")
}
if ref, ok := tc.Arg.(*ast.ColumnRef); ok {
name = astutils.Join(ref.Fields, "_")
}
// TODO Validate column names
col = toColumn(tc.TypeName)

if x, ok := tc.Arg.(*ast.A_Const); ok {
if _, ok := x.Val.(*ast.Null); ok {
col.NotNull = false
}
}
col.Name = name

} else if aconst, ok := n.(*ast.A_Const); ok {
col = convertAConstToColumn(aconst, name)
} else if aexpr, ok := n.(*ast.A_Expr); ok {
col = convertAExprToColumn(aexpr, name)
} else {
col = &Column{Name: name, DataType: "any", NotNull: false}
}

return col, nil
}

func convertAExprToColumn(aexpr *ast.A_Expr, name string) *Column {
var col *Column
switch op := astutils.Join(aexpr.Name, ""); {
case lang.IsComparisonOperator(op):
// TODO: Generate a name for these operations
col = &Column{Name: name, DataType: "bool", NotNull: true}
case lang.IsMathematicalOperator(op):
col = &Column{Name: name, DataType: "int", NotNull: true}
case lang.IsJSONOperator(op) && lang.IsJSONResultAsText(op):
col = &Column{Name: name, DataType: "text", NotNull: false}
default:
col = &Column{Name: name, DataType: "any", NotNull: false}
}

return col
}

func convertAConstToColumn(aconst *ast.A_Const, name string) (*Column) {
var col *Column
switch aconst.Val.(type) {
case *ast.String:
col = &Column{Name: name, DataType: "text", NotNull: true}
case *ast.Integer:
col = &Column{Name: name, DataType: "int", NotNull: true}
case *ast.Float:
col = &Column{Name: name, DataType: "float", NotNull: true}
case *ast.Boolean:
col = &Column{Name: name, DataType: "bool", NotNull: true}
case *ast.Null:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is new addition to handle null const value. I need to distinguish between any and null, because null in CaseExpr would determine query column result nullability, but does not make it included in column type decision.

for ast.A_Const condition, i have converted DataType null back to any (line 136, outputColumns function) because we don't know its type and to preserve backward compatibility

col = &Column{Name: name, DataType: "null", NotNull: false}
default:
col = &Column{Name: name, DataType: "any", NotNull: false}
}
return col
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://github.com/sqlc-dev/sqlc/issues/3710
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"contexts": ["base"]
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading