Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #574 from erizocosmico/fix/column-prune-fix-fields
Browse files Browse the repository at this point in the history
analyzer: fix fields of subqueries too on prune_columns
  • Loading branch information
ajnavarro authored Dec 17, 2018
2 parents 47cd1f8 + ec6bee6 commit ac59802
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 60 deletions.
103 changes: 44 additions & 59 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

findUsedColumns(columns, n)

n, err := addSubqueryBarriers(n)
if err != nil {
return nil, err
}

n, err = pruneUnusedColumns(n, columns)
n, err := pruneUnusedColumns(n, columns)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -81,12 +76,7 @@ func pruneSubqueryColumns(

findUsedColumns(columns, n.Child)

node, err := addSubqueryBarriers(n.Child)
if err != nil {
return nil, err
}

node, err = pruneUnusedColumns(node, columns)
node, err := pruneUnusedColumns(n.Child, columns)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -126,30 +116,19 @@ func findUsedColumns(columns usedColumns, n sql.Node) {
})
}

func addSubqueryBarriers(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
sq, ok := n.(*plan.SubqueryAlias)
if !ok {
return n, nil
}

return &subqueryBarrier{sq}, nil
})
}

func pruneSubqueries(
ctx *sql.Context,
a *Analyzer,
n sql.Node,
parentColumns usedColumns,
) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
barrier, ok := n.(*subqueryBarrier)
subq, ok := n.(*plan.SubqueryAlias)
if !ok {
return n, nil
}

return pruneSubqueryColumns(ctx, a, barrier.SubqueryAlias, parentColumns)
return pruneSubqueryColumns(ctx, a, subq, parentColumns)
})
}

Expand All @@ -173,39 +152,53 @@ type tableColumnPair struct {

func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
exp, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

var schema sql.Schema
for _, c := range n.Children() {
schema = append(schema, c.Schema()...)
}
switch n := n.(type) {
case *plan.SubqueryAlias:
child, err := fixRemainingFieldsIndexes(n.Child)
if err != nil {
return nil, err
}

if len(schema) == 0 {
return n, nil
}
return plan.NewSubqueryAlias(n.Name(), child), nil
default:
exp, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

indexes := make(map[tableColumnPair]int)
for i, col := range schema {
indexes[tableColumnPair{col.Source, col.Name}] = i
}
var schema sql.Schema
for _, c := range n.Children() {
schema = append(schema, c.Schema()...)
}

return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
gf, ok := e.(*expression.GetField)
if !ok {
return e, nil
if len(schema) == 0 {
return n, nil
}

idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}]
if !ok {
return nil, fmt.Errorf("unable to find column %q of table %q", gf.Name(), gf.Table())
indexes := make(map[tableColumnPair]int)
for i, col := range schema {
indexes[tableColumnPair{col.Source, col.Name}] = i
}

ngf := *gf
return ngf.WithIndex(idx), nil
})
return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
gf, ok := e.(*expression.GetField)
if !ok {
return e, nil
}

idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}]
if !ok {
return nil, fmt.Errorf("unable to find column %q of table %q", gf.Name(), gf.Table())
}

if idx == gf.Index() {
return gf, nil
}

ngf := *gf
return ngf.WithIndex(idx), nil
})
}
})
}

Expand Down Expand Up @@ -290,11 +283,3 @@ func shouldPruneExpr(e sql.Expression, cols usedColumns) bool {

return true
}

type subqueryBarrier struct {
*plan.SubqueryAlias
}

func (b *subqueryBarrier) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
return f(b)
}
2 changes: 1 addition & 1 deletion sql/analyzer/prune_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func TestPruneColumns(t *testing.T) {
),
expression.NewEquals(
gf(0, "t1", "foo"),
gf(3, "t2", "foo"),
gf(1, "t2", "foo"),
),
),
),
Expand Down

0 comments on commit ac59802

Please sign in to comment.