Implement case insensitive name matching for BigQuery
hashhar authored and losipiuk committed Feb 15, 2021
1 parent 281e7b1 commit cda6386
Showing 13 changed files with 693 additions and 170 deletions.
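
The new behavior is driven by configuration read in BigQueryClient's constructor (see the Java changes below). A minimal wiring sketch, assuming airlift-style fluent setters on BigQueryConfig that mirror the getters used in this diff (the setter names and the bigQueryService variable are assumptions, not part of this commit):

static BigQueryClient newClient(BigQuery bigQueryService)
{
    BigQueryConfig config = new BigQueryConfig()
            .setCaseInsensitiveNameMatching(true)                               // assumed setter
            .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, MINUTES));  // assumed setter
    return new BigQueryClient(bigQueryService, config);
}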
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
@@ -295,6 +295,13 @@ jobs:
if [ "${BIGQUERY_CREDENTIALS_KEY}" != "" ]; then
./mvnw test ${MAVEN_TEST} -pl :trino-bigquery -Pcloud-tests -Dbigquery.credentials-key="${BIGQUERY_CREDENTIALS_KEY}"
fi
- name: Cloud BigQuery Case Insensitive Mapping Tests
env:
BIGQUERY_CREDENTIALS_KEY: ${{ secrets.BIGQUERY_CREDENTIALS_KEY }}
run: |
if [ "${BIGQUERY_CREDENTIALS_KEY}" != "" ]; then
./mvnw test ${MAVEN_TEST} -pl :trino-bigquery -Pcloud-tests-case-insensitive-mapping -Dbigquery.credentials-key="${BIGQUERY_CREDENTIALS_KEY}"
fi
pt:
runs-on: ubuntu-latest
27 changes: 27 additions & 0 deletions plugin/trino-bigquery/pom.xml
@@ -79,6 +79,11 @@
<artifactId>log-manager</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>units</artifactId>
</dependency>

<dependency>
<groupId>com.google.api</groupId>
<artifactId>gax</artifactId>
@@ -287,6 +292,7 @@
<excludes>
<!-- If you are adding entry here also add an entry to cloud-tests profile below -->
<exclude>**/TestBigQueryIntegrationSmokeTest.java</exclude>
<exclude>**/TestBigQueryCaseInsensitiveMapping.java</exclude>
</excludes>
</configuration>
</plugin>
@@ -313,5 +319,26 @@
</plugins>
</build>
</profile>

<!-- Separate profile for TestBigQueryCaseInsensitiveMapping until we can fully isolate it -->
<profile>
<id>cloud-tests-case-insensitive-mapping</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<includes>
<include>**/TestBigQueryCaseInsensitiveMapping.java</include>
</includes>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>
plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java
@@ -17,64 +17,152 @@
import com.google.cloud.bigquery.BigQueryException;
import com.google.cloud.bigquery.Dataset;
import com.google.cloud.bigquery.DatasetId;
import com.google.cloud.bigquery.DatasetInfo;
import com.google.cloud.bigquery.Job;
import com.google.cloud.bigquery.JobInfo;
import com.google.cloud.bigquery.QueryJobConfiguration;
import com.google.cloud.bigquery.Schema;
import com.google.cloud.bigquery.Table;
import com.google.cloud.bigquery.TableDefinition;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.TableResult;
import com.google.cloud.http.BaseHttpServiceException;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import io.airlift.units.Duration;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.TableNotFoundException;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.Iterator;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.StreamSupport;

import static com.google.cloud.bigquery.TableDefinition.Type.TABLE;
import static com.google.cloud.bigquery.TableDefinition.Type.VIEW;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Streams.stream;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_AMBIGUOUS_OBJECT_NAME;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.stream.Collectors.joining;

// holds caches and mappings
// Trino converts dataset and table names to lower case, while BigQuery is case sensitive
// the caches below map those lowercase names back to the case-sensitive remote names
class BigQueryClient
{
private final BigQuery bigQuery;
private final Optional<String> viewMaterializationProject;
private final Optional<String> viewMaterializationDataset;
private final ConcurrentMap<TableId, TableId> tableIds = new ConcurrentHashMap<>();
private final ConcurrentMap<DatasetId, DatasetId> datasetIds = new ConcurrentHashMap<>();
private final boolean caseInsensitiveNameMatching;
private final Cache<String, Optional<RemoteDatabaseObject>> remoteDatasets;
private final Cache<TableId, Optional<RemoteDatabaseObject>> remoteTables;

BigQueryClient(BigQuery bigQuery, BigQueryConfig config)
{
this.bigQuery = bigQuery;
this.viewMaterializationProject = config.getViewMaterializationProject();
this.viewMaterializationDataset = config.getViewMaterializationDataset();
Duration caseInsensitiveNameMatchingCacheTtl = requireNonNull(config.getCaseInsensitiveNameMatchingCacheTtl(), "caseInsensitiveNameMatchingCacheTtl is null");

this.caseInsensitiveNameMatching = config.isCaseInsensitiveNameMatching();
CacheBuilder<Object, Object> remoteNamesCacheBuilder = CacheBuilder.newBuilder()
.expireAfterWrite(caseInsensitiveNameMatchingCacheTtl.toMillis(), MILLISECONDS);
this.remoteDatasets = remoteNamesCacheBuilder.build();
this.remoteTables = remoteNamesCacheBuilder.build();
}
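
Both name caches share one builder configuration but are independent instances; an entry expires a fixed interval after it was written, so a renamed BigQuery object is picked up again once the TTL lapses. A tiny sketch of the expiry behavior (the key, value, and 1-minute TTL are hypothetical):

Cache<String, String> cache = CacheBuilder.newBuilder()
        .expireAfterWrite(Duration.valueOf("1m").toMillis(), MILLISECONDS)
        .build();
cache.put("sales", "SALES");
cache.getIfPresent("sales"); // "SALES" for the next minute, null afterwards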

Optional<RemoteDatabaseObject> toRemoteDataset(String projectId, String datasetName)
{
requireNonNull(projectId, "projectId is null");
requireNonNull(datasetName, "datasetName is null");
verify(datasetName.codePoints().noneMatch(Character::isUpperCase), "Expected schema name from internal metadata to be lowercase: %s", datasetName);
if (!caseInsensitiveNameMatching) {
return Optional.of(RemoteDatabaseObject.of(datasetName));
}

@Nullable Optional<RemoteDatabaseObject> remoteDataset = remoteDatasets.getIfPresent(datasetName);
if (remoteDataset != null) {
return remoteDataset;
}

// cache miss, reload the cache
Map<String, Optional<RemoteDatabaseObject>> mapping = new HashMap<>();
for (Dataset dataset : listDatasets(projectId)) {
mapping.merge(
dataset.getDatasetId().getDataset().toLowerCase(ENGLISH),
Optional.of(RemoteDatabaseObject.of(dataset.getDatasetId().getDataset())),
(currentValue, collision) -> currentValue.map(current -> current.registerCollision(collision.get().getOnlyRemoteName())));
}

// explicitly cache the information if the requested dataset doesn't exist
if (!mapping.containsKey(datasetName)) {
mapping.put(datasetName, Optional.empty());
}
// store the refreshed mapping so later lookups hit the cache until the TTL expires
remoteDatasets.putAll(mapping);

verify(mapping.containsKey(datasetName));
return mapping.get(datasetName);
}
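
The mapping.merge call above is where ambiguity is recorded: two remote datasets whose names differ only by case collide on the lowercase key, and the remapping function folds the second remote name into the existing entry. A standalone illustration with hypothetical dataset names:

Map<String, Optional<RemoteDatabaseObject>> mapping = new HashMap<>();
for (String remoteName : List.of("Sales", "SALES")) {
    mapping.merge(
            remoteName.toLowerCase(ENGLISH),
            Optional.of(RemoteDatabaseObject.of(remoteName)),
            (currentValue, collision) -> currentValue.map(current -> current.registerCollision(collision.get().getOnlyRemoteName())));
}
// mapping.get("sales") now holds both "Sales" and "SALES", so calling
// getOnlyRemoteName() on it throws BIGQUERY_AMBIGUOUS_OBJECT_NAME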

TableInfo getTable(TableId tableId)
Optional<RemoteDatabaseObject> toRemoteTable(String projectId, String remoteDatasetName, String tableName)
{
TableId bigQueryTableId = tableIds.get(tableId);
Table table = bigQuery.getTable(bigQueryTableId != null ? bigQueryTableId : tableId);
if (table != null) {
tableIds.putIfAbsent(tableId, table.getTableId());
datasetIds.putIfAbsent(toDatasetId(tableId), toDatasetId(table.getTableId()));
requireNonNull(projectId, "projectId is null");
requireNonNull(remoteDatasetName, "remoteDatasetName is null");
requireNonNull(tableName, "tableName is null");
verify(tableName.codePoints().noneMatch(Character::isUpperCase), "Expected table name from internal metadata to be lowercase: %s", tableName);
if (!caseInsensitiveNameMatching) {
return Optional.of(RemoteDatabaseObject.of(tableName));
}

TableId cacheKey = TableId.of(projectId, remoteDatasetName, tableName);
@Nullable Optional<RemoteDatabaseObject> remoteTable = remoteTables.getIfPresent(cacheKey);
if (remoteTable != null) {
return remoteTable;
}
return table;

// cache miss, reload the cache
Map<TableId, Optional<RemoteDatabaseObject>> mapping = new HashMap<>();
for (Table table : listTables(DatasetId.of(projectId, remoteDatasetName), TABLE, VIEW)) {
mapping.merge(
tableIdToLowerCase(table.getTableId()),
Optional.of(RemoteDatabaseObject.of(table.getTableId().getTable())),
(currentValue, collision) -> currentValue.map(current -> current.registerCollision(collision.get().getOnlyRemoteName())));
}

// explicitly cache the information if the requested table doesn't exist
if (!mapping.containsKey(cacheKey)) {
mapping.put(cacheKey, Optional.empty());
}
// store the refreshed mapping so later lookups hit the cache until the TTL expires
remoteTables.putAll(mapping);

verify(mapping.containsKey(cacheKey));
return mapping.get(cacheKey);
}

private static TableId tableIdToLowerCase(TableId tableId)
{
return TableId.of(
tableId.getProject(),
tableId.getDataset(),
tableId.getTable().toLowerCase(ENGLISH));
}

DatasetInfo getDataset(DatasetId datasetId)
{
return bigQuery.getDataset(datasetId);
}

DatasetId toDatasetId(TableId tableId)
TableInfo getTable(TableId remoteTableId)
{
return DatasetId.of(tableId.getProject(), tableId.getDataset());
// TODO: Return Optional and make callers handle missing value
return bigQuery.getTable(remoteTableId);
}

String getProjectId()
@@ -84,43 +172,32 @@ String getProjectId()

Iterable<Dataset> listDatasets(String projectId)
{
Iterator<Dataset> datasets = bigQuery.listDatasets(projectId).iterateAll().iterator();
return () -> Iterators.transform(datasets, this::addDataSetMappingIfNeeded);
return bigQuery.listDatasets(projectId).iterateAll();
}

Iterable<Table> listTables(DatasetId datasetId, TableDefinition.Type... types)
Iterable<Table> listTables(DatasetId remoteDatasetId, TableDefinition.Type... types)
{
Set<TableDefinition.Type> allowedTypes = ImmutableSet.copyOf(types);
DatasetId bigQueryDatasetId = datasetIds.getOrDefault(datasetId, datasetId);
Iterable<Table> allTables = bigQuery.listTables(bigQueryDatasetId).iterateAll();
return StreamSupport.stream(allTables.spliterator(), false)
Iterable<Table> allTables = bigQuery.listTables(remoteDatasetId).iterateAll();
return stream(allTables)
.filter(table -> allowedTypes.contains(table.getDefinition().getType()))
.collect(toImmutableList());
}

private Dataset addDataSetMappingIfNeeded(Dataset dataset)
{
DatasetId bigQueryDatasetId = dataset.getDatasetId();
DatasetId trinoDatasetId = DatasetId.of(bigQueryDatasetId.getProject(), bigQueryDatasetId.getDataset().toLowerCase(ENGLISH));
datasetIds.putIfAbsent(trinoDatasetId, bigQueryDatasetId);
return dataset;
}

TableId createDestinationTable(TableId tableId)
{
String project = viewMaterializationProject.orElse(tableId.getProject());
String dataset = viewMaterializationDataset.orElse(tableId.getDataset());
DatasetId datasetId = mapIfNeeded(project, dataset);

String remoteDatasetName = toRemoteDataset(project, dataset)
.map(RemoteDatabaseObject::getOnlyRemoteName)
.orElse(dataset);

DatasetId datasetId = DatasetId.of(project, remoteDatasetName);
String name = format("_pbc_%s", randomUUID().toString().toLowerCase(ENGLISH).replace("-", ""));
return TableId.of(datasetId.getProject(), datasetId.getDataset(), name);
}
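
The destination name is a fixed prefix plus a dash-stripped lowercase UUID, so repeated materializations never collide. A hypothetical generated value:

// format("_pbc_%s", randomUUID().toString().toLowerCase(ENGLISH).replace("-", ""))
//   -> "_pbc_5f3c8e2a41b84d0f9a6b2c7d1e403958" (UUID made up; each call yields a fresh one)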

private DatasetId mapIfNeeded(String project, String dataset)
{
DatasetId datasetId = DatasetId.of(project, dataset);
return datasetIds.getOrDefault(datasetId, datasetId);
}

Table update(TableInfo table)
{
return bigQuery.update(table);
@@ -159,7 +236,75 @@ String selectSql(TableId table, String formattedColumns)

private String fullTableName(TableId tableId)
{
tableId = tableIds.getOrDefault(tableId, tableId);
String remoteSchemaName = toRemoteDataset(tableId.getProject(), tableId.getDataset())
.map(RemoteDatabaseObject::getOnlyRemoteName)
.orElse(tableId.getDataset());
String remoteTableName = toRemoteTable(tableId.getProject(), remoteSchemaName, tableId.getTable())
.map(RemoteDatabaseObject::getOnlyRemoteName)
.orElse(tableId.getTable());
tableId = TableId.of(tableId.getProject(), remoteSchemaName, remoteTableName);
return format("%s.%s.%s", tableId.getProject(), tableId.getDataset(), tableId.getTable());
}
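
Taken together, generated SQL refers to the remote, case-correct names even though Trino hands the connector lowercase ones. A hypothetical trace (fullTableName is private, so this illustrates intended behavior rather than a public call):

// With matching enabled and a remote dataset "MyDataset" holding table "MyTable":
// fullTableName(TableId.of("my-project", "mydataset", "mytable"))
//   -> "my-project.MyDataset.MyTable"
// With matching disabled, the lowercase names pass through unchanged:
//   -> "my-project.mydataset.mytable"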

List<BigQueryColumnHandle> getColumns(BigQueryTableHandle tableHandle)
{
TableInfo tableInfo = getTable(tableHandle.getRemoteTableName().toTableId());
if (tableInfo == null) {
throw new TableNotFoundException(
tableHandle.getSchemaTableName(),
format("Table '%s' not found", tableHandle.getSchemaTableName()));
}
@Nullable Schema schema = tableInfo.getDefinition().getSchema();
if (schema == null) {
throw new TableNotFoundException(
tableHandle.getSchemaTableName(),
format("Table '%s' has no schema", tableHandle.getSchemaTableName()));
}
return schema.getFields()
.stream()
.map(Conversions::toColumnHandle)
.collect(toImmutableList());
}

static final class RemoteDatabaseObject
{
private final Set<String> remoteNames;

private RemoteDatabaseObject(Set<String> remoteNames)
{
this.remoteNames = ImmutableSet.copyOf(remoteNames);
}

public static RemoteDatabaseObject of(String remoteName)
{
return new RemoteDatabaseObject(ImmutableSet.of(remoteName));
}

public RemoteDatabaseObject registerCollision(String ambiguousName)
{
return new RemoteDatabaseObject(ImmutableSet.<String>builderWithExpectedSize(remoteNames.size() + 1)
.addAll(remoteNames)
.add(ambiguousName)
.build());
}

public String getAnyRemoteName()
{
return Collections.min(remoteNames);
}

public String getOnlyRemoteName()
{
if (!isAmbiguous()) {
return getOnlyElement(remoteNames);
}

throw new TrinoException(BIGQUERY_AMBIGUOUS_OBJECT_NAME, "Found ambiguous names in BigQuery when looking up '" + getAnyRemoteName().toLowerCase(ENGLISH) + "': " + remoteNames);
}

public boolean isAmbiguous()
{
return remoteNames.size() > 1;
}
}
}
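
RemoteDatabaseObject's contract in brief: getOnlyRemoteName() is the safe accessor when a read must be unambiguous, while getAnyRemoteName() supplies a deterministic representative for error messages. A hypothetical usage sketch:

RemoteDatabaseObject unique = RemoteDatabaseObject.of("Orders");
unique.isAmbiguous();          // false
unique.getOnlyRemoteName();    // "Orders"

RemoteDatabaseObject ambiguous = unique.registerCollision("ORDERS");
ambiguous.isAmbiguous();       // true
ambiguous.getAnyRemoteName();  // "ORDERS" (Collections.min, the lexicographically smallest)
ambiguous.getOnlyRemoteName(); // throws TrinoException(BIGQUERY_AMBIGUOUS_OBJECT_NAME)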
