Skip to content

Commit

Permalink
Fix incorrect result when aggregating count BigQuery view
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Dec 21, 2020
1 parent 320a9bd commit 3b1c23e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
Expand Up @@ -15,6 +15,7 @@

import com.google.cloud.bigquery.BigQueryException;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.TableResult;
import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession;
import com.google.common.collect.ImmutableList;
Expand All @@ -37,8 +38,11 @@
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.cloud.bigquery.TableDefinition.Type.TABLE;
import static com.google.cloud.bigquery.TableDefinition.Type.VIEW;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY;
import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static java.util.stream.IntStream.range;
Expand Down Expand Up @@ -124,8 +128,19 @@ private List<BigQuerySplit> createEmptyProjection(TableId tableId, int actualPar
numberOfRows = result.iterateAll().iterator().next().get(0).getLongValue();
}
else {
// no filters, so we can take the value from the table info
numberOfRows = bigQueryClient.getTable(tableId).getNumRows().longValue();
// no filters, so we can take the value from the table info when the object is TABLE
TableInfo tableInfo = bigQueryClient.getTable(tableId);
if (tableInfo.getDefinition().getType() == TABLE) {
numberOfRows = tableInfo.getNumRows().longValue();
}
else if (tableInfo.getDefinition().getType() == VIEW) {
String sql = bigQueryClient.selectSql(tableId, "COUNT(*)");
TableResult result = bigQueryClient.query(sql);
numberOfRows = result.iterateAll().iterator().next().get(0).getLongValue();
}
else {
throw new PrestoException(NOT_SUPPORTED, "Unsupported table type: " + tableInfo.getDefinition().getType());
}
}

long rowsPerSplit = numberOfRows / actualParallelism;
Expand Down
Expand Up @@ -105,6 +105,33 @@ public void testPredicatePushdownPrunnedColumns()
"VALUES (1)");
}

@Test(description = "regression test for https://github.com/prestosql/presto/issues/5635")
public void testCountAggregationView()
{
BigQuery client = createBigQueryClient();

String tableName = "test.count_aggregation_table";
String viewName = "test.count_aggregation_view";

executeBigQuerySql(client, "DROP TABLE IF EXISTS " + tableName);
executeBigQuerySql(client, "DROP VIEW IF EXISTS " + viewName);
executeBigQuerySql(client, "CREATE TABLE " + tableName + " (a INT64, b INT64, c INT64)");
executeBigQuerySql(client, "INSERT INTO " + tableName + " VALUES (1, 2, 3), (4, 5, 6)");
executeBigQuerySql(client, "CREATE VIEW " + viewName + " AS SELECT * FROM " + tableName);

assertQuery(
"SELECT count(*) FROM " + viewName,
"VALUES (2)");

assertQuery(
"SELECT count(*) FROM " + viewName + " WHERE a = 1",
"VALUES (1)");

assertQuery(
"SELECT count(a) FROM " + viewName + " WHERE b = 2",
"VALUES (1)");
}

private static void executeBigQuerySql(BigQuery bigquery, String query)
{
QueryJobConfiguration queryConfig = QueryJobConfiguration.newBuilder(query)
Expand Down

0 comments on commit 3b1c23e

Please sign in to comment.