Skip to content

Commit

Permalink
add new rules, fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshSingla committed Sep 28, 2023
1 parent 7e60702 commit a6ee632
Show file tree
Hide file tree
Showing 5 changed files with 492 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,29 @@
package org.apache.druid.msq.test;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.inject.Injector;
import com.google.inject.Module;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.guice.DruidInjectorBuilder;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.msq.exec.WorkerMemoryParameters;
import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.TestGroupByBuffers;
import org.apache.druid.server.QueryLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.CalciteUnionQueryTest;
import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.run.SqlEngine;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -109,4 +121,57 @@ public void testUnionIsUnplannable()
{

}

@Test
public void testUnionOnSubqueries()
{
testQuery(
"SELECT\n"
+ " SUM(cnt),\n"
+ " COUNT(*)\n"
+ "FROM (\n"
+ " (SELECT dim2, SUM(cnt) AS cnt FROM druid.foo GROUP BY dim2)\n"
+ " UNION ALL\n"
+ " (SELECT dim2, SUM(cnt) AS cnt FROM druid.foo GROUP BY dim2)\n"
+ ")",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
new UnionDataSource(
ImmutableList.of(
new TableDataSource(CalciteTests.DATASOURCE1),
new TableDataSource(CalciteTests.DATASOURCE1)
)
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0")))
.setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new LongSumAggregatorFactory("_a0", "a0"),
new CountAggregatorFactory("_a1")
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
NullHandling.replaceWithDefault() ?
ImmutableList.of(
new Object[]{12L, 3L}
) :
ImmutableList.of(
new Object[]{12L, 4L}
)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
*/
public class UnionDataSource implements DataSource
{
@JsonProperty

@JsonProperty("dataSources")
private final List<DataSource> dataSources;

@JsonCreator
Expand Down Expand Up @@ -76,7 +77,6 @@ public Set<String> getTableNames()
}

// TODO: native only method
@JsonProperty
public List<TableDataSource> getDataSourcesAsTableDataSources()
{
return dataSources.stream()
Expand All @@ -102,9 +102,7 @@ public DataSource withChildren(List<DataSource> children)
throw new IAE("Expected [%d] children, got [%d]", dataSources.size(), children.size());
}

return new UnionDataSource(
children.stream().map(dataSource -> (TableDataSource) dataSource).collect(Collectors.toList())
);
return new UnionDataSource(children);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.sql.calcite.rel;

import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.table.RowSignatures;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Represents a query on top of a {@link UnionDataSource}. This is used to represent a "UNION ALL" of regular table
* datasources.
* <p>
* See {@link DruidUnionRel} for a version that can union any set of queries together (not just regular tables),
* but also must be the outermost rel of a query plan. In the future we expect that {@link UnionDataSource} will gain
* the ability to union query datasources together, and then this class could replace {@link DruidUnionRel}.
*/
public class DruidFreeUnionDataSourceRel extends DruidRel<DruidFreeUnionDataSourceRel>
{
private static final TableDataSource DUMMY_DATA_SOURCE = new TableDataSource("__union__");

private final Union unionRel;
private final List<String> unionColumnNames;
private final PartialDruidQuery partialQuery;

private DruidFreeUnionDataSourceRel(
final RelOptCluster cluster,
final RelTraitSet traitSet,
final Union unionRel,
final List<String> unionColumnNames,
final PartialDruidQuery partialQuery,
final PlannerContext plannerContext
)
{
super(cluster, traitSet, plannerContext);
this.unionRel = unionRel;
this.unionColumnNames = unionColumnNames;
this.partialQuery = partialQuery;
}

public static DruidFreeUnionDataSourceRel create(
final Union unionRel,
final List<String> unionColumnNames,
final PlannerContext plannerContext
)
{
return new DruidFreeUnionDataSourceRel(
unionRel.getCluster(),
unionRel.getTraitSet(),
unionRel,
unionColumnNames,
PartialDruidQuery.create(unionRel),
plannerContext
);
}

public List<String> getUnionColumnNames()
{
return unionColumnNames;
}

@Override
public PartialDruidQuery getPartialDruidQuery()
{
return partialQuery;
}

@Override
public DruidFreeUnionDataSourceRel withPartialQuery(final PartialDruidQuery newQueryBuilder)
{
return new DruidFreeUnionDataSourceRel(
getCluster(),
newQueryBuilder.getTraitSet(getConvention()),
unionRel,
unionColumnNames,
newQueryBuilder,
getPlannerContext()
);
}

@Override
public DruidQuery toDruidQuery(final boolean finalizeAggregations)
{
final List<DataSource> dataSources = new ArrayList<>();
RowSignature signature = null;

for (final RelNode relNode : unionRel.getInputs()) {
final DruidRel<?> druidRel = (DruidRel<?>) relNode;

final DruidQuery query = druidRel.toDruidQuery(false);
final DataSource dataSource;
if (druidRel instanceof DruidQueryRel) {
dataSource = query.getDataSource();
} else {
dataSource = new QueryDataSource(query.getQuery());
}

if (signature == null) {
signature = query.getOutputRowSignature();
}

if (signature.getColumnNames().equals(query.getOutputRowSignature().getColumnNames())) {
dataSources.add(dataSource);
} else {
getPlannerContext().setPlanningError(
"There is a mismatch between the output row signature of input tables and the row signature of union output.");
throw new CannotBuildQueryException(druidRel);
}
}

if (signature == null) {
// No inputs.
throw new CannotBuildQueryException(unionRel);
}

// Sanity check: the columns we think we're building off must equal the "unionColumnNames" registered at
// creation time.
if (!signature.getColumnNames().equals(unionColumnNames)) {
throw new CannotBuildQueryException(unionRel);
}

return partialQuery.build(
new UnionDataSource(dataSources),
signature,
getPlannerContext(),
getCluster().getRexBuilder(),
finalizeAggregations
);
}

@Override
public DruidQuery toDruidQueryForExplaining()
{
return partialQuery.build(
DUMMY_DATA_SOURCE,
RowSignatures.fromRelDataType(
unionRel.getRowType().getFieldNames(),
unionRel.getRowType()
),
getPlannerContext(),
getCluster().getRexBuilder(),
false
);
}

@Override
public DruidFreeUnionDataSourceRel asDruidConvention()
{
return new DruidFreeUnionDataSourceRel(
getCluster(),
getTraitSet().replace(DruidConvention.instance()),
(Union) unionRel.copy(
unionRel.getTraitSet(),
unionRel.getInputs()
.stream()
.map(input -> RelOptRule.convert(input, DruidConvention.instance()))
.collect(Collectors.toList())
),
unionColumnNames,
partialQuery,
getPlannerContext()
);
}

@Override
public List<RelNode> getInputs()
{
return unionRel.getInputs();
}

@Override
public void replaceInput(int ordinalInParent, RelNode p)
{
unionRel.replaceInput(ordinalInParent, p);
}

@Override
public RelNode copy(final RelTraitSet traitSet, final List<RelNode> inputs)
{
return new DruidFreeUnionDataSourceRel(
getCluster(),
traitSet,
(Union) unionRel.copy(unionRel.getTraitSet(), inputs),
unionColumnNames,
partialQuery,
getPlannerContext()
);
}

@Override
public Set<String> getDataSourceNames()
{
final Set<String> retVal = new HashSet<>();

for (final RelNode input : unionRel.getInputs()) {
retVal.addAll(((DruidRel<?>) input).getDataSourceNames());
}

return retVal;
}

@Override
public RelWriter explainTerms(RelWriter pw)
{
final String queryString;
final DruidQuery druidQuery = toDruidQueryForExplaining();

try {
queryString = getPlannerContext().getJsonMapper().writeValueAsString(druidQuery.getQuery());
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

for (int i = 0; i < unionRel.getInputs().size(); i++) {
pw.input(StringUtils.format("input#%d", i), unionRel.getInputs().get(i));
}

return pw.item("query", queryString)
.item("signature", druidQuery.getOutputRowSignature());
}

@Override
protected RelDataType deriveRowType()
{
return partialQuery.getRowType();
}

@Override
public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadataQuery mq)
{
return planner.getCostFactory().makeZeroCost();
}
}
Loading

0 comments on commit a6ee632

Please sign in to comment.