feat: add default LoadJobConfig to Client (#1526)
chelsea-lin committed Mar 17, 2023
1 parent aa0fa02 commit a2520ca
Showing 5 changed files with 621 additions and 56 deletions.
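The change lets callers configure load behavior once on the client instead of passing a ``LoadJobConfig`` to every ``load_table_*`` call. A minimal sketch of the intended usage, assuming reachable credentials; the project, dataset, and gs:// URI below are placeholders:

    from google.cloud import bigquery

    # "my-project", "my_dataset", and the gs:// URI are placeholders.
    default_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.CSV,
        ignore_unknown_values=True,  # applied to every load unless overridden
    )
    client = bigquery.Client(project="my-project", default_load_job_config=default_config)

    # No per-call job_config needed; the client default is merged in.
    load_job = client.load_table_from_uri(
        "gs://my-bucket/data.csv", "my-project.my_dataset.my_table"
    )
    load_job.result()  # block until the load job completes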
121 changes: 71 additions & 50 deletions google/cloud/bigquery/client.py
@@ -210,6 +210,9 @@ class Client(ClientWithProject):
default_query_job_config (Optional[google.cloud.bigquery.job.QueryJobConfig]):
Default ``QueryJobConfig``.
Will be merged into job configs passed into the ``query`` method.
+        default_load_job_config (Optional[google.cloud.bigquery.job.LoadJobConfig]):
+            Default ``LoadJobConfig``.
+            Will be merged into job configs passed into the ``load_table_*`` methods.
client_info (Optional[google.api_core.client_info.ClientInfo]):
The client info used to send a user-agent string along with API
requests. If ``None``, then default info will be used. Generally,
@@ -235,6 +238,7 @@ def __init__(
_http=None,
location=None,
default_query_job_config=None,
+        default_load_job_config=None,
client_info=None,
client_options=None,
) -> None:
@@ -260,6 +264,7 @@ def __init__(
self._connection = Connection(self, **kw_args)
self._location = location
self._default_query_job_config = copy.deepcopy(default_query_job_config)
+        self._default_load_job_config = copy.deepcopy(default_load_job_config)

@property
def location(self):
@@ -277,6 +282,17 @@ def default_query_job_config(self):
def default_query_job_config(self, value: QueryJobConfig):
self._default_query_job_config = copy.deepcopy(value)

+    @property
+    def default_load_job_config(self):
+        """Default ``LoadJobConfig``.
+        Will be merged into job configs passed into the ``load_table_*`` methods.
+        """
+        return self._default_load_job_config
+
+    @default_load_job_config.setter
+    def default_load_job_config(self, value: LoadJobConfig):
+        self._default_load_job_config = copy.deepcopy(value)
+
def close(self):
"""Close the underlying transport objects, releasing system resources.
@@ -2330,8 +2346,8 @@ def load_table_from_uri(
Raises:
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

@@ -2348,11 +2364,14 @@ def load_table_from_uri(

destination = _table_arg_to_table_ref(destination, default_project=self.project)

-        if job_config:
-            job_config = copy.deepcopy(job_config)
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()

-        load_job = job.LoadJob(job_ref, source_uris, destination, self, job_config)
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        load_job = job.LoadJob(job_ref, source_uris, destination, self, new_job_config)
load_job._begin(retry=retry, timeout=timeout)

return load_job
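Because ``_fill_from_default`` gives the per-call config precedence, a client default only fills in keys the caller left unset. A sketch of the resulting precedence, with placeholder names:

    from google.cloud import bigquery

    client = bigquery.Client(
        default_load_job_config=bigquery.LoadJobConfig(
            write_disposition=bigquery.WriteDisposition.WRITE_APPEND,
            ignore_unknown_values=True,
        )
    )

    # The per-call config wins on write_disposition but inherits
    # ignore_unknown_values=True from the client default.
    load_job = client.load_table_from_uri(
        "gs://my-bucket/part-*.csv",
        "my-project.my_dataset.events",
        job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.CSV,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        ),
    )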
@@ -2424,8 +2443,8 @@ def load_table_from_file(
mode.
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

@@ -2437,10 +2456,15 @@ def load_table_from_file(

destination = _table_arg_to_table_ref(destination, default_project=self.project)
job_ref = job._JobReference(job_id, project=project, location=location)
-        if job_config:
-            job_config = copy.deepcopy(job_config)
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-        load_job = job.LoadJob(job_ref, None, destination, self, job_config)
+
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()
+
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        load_job = job.LoadJob(job_ref, None, destination, self, new_job_config)
job_resource = load_job.to_api_repr()

if rewind:
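The same merge now applies to file uploads; ``rewind=True`` (the branch right above) seeks the stream back to the start before upload. A sketch with an in-memory CSV and a placeholder table ID:

    import io

    from google.cloud import bigquery

    client = bigquery.Client(
        default_load_job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.CSV,
            skip_leading_rows=1,
            autodetect=True,
        )
    )

    buf = io.BytesIO(b"name,score\na,1\nb,2\n")
    load_job = client.load_table_from_file(
        buf, "my-project.my_dataset.scores", rewind=True
    )
    load_job.result()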
@@ -2564,8 +2588,8 @@ def load_table_from_dataframe(
If a usable parquet engine cannot be found. This method
requires :mod:`pyarrow` to be installed.
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

-        if job_config:
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-            # Make a copy so that the job config isn't modified in-place.
-            job_config_properties = copy.deepcopy(job_config._properties)
-            job_config = job.LoadJobConfig()
-            job_config._properties = job_config_properties
-
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()
+
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
supported_formats = {job.SourceFormat.CSV, job.SourceFormat.PARQUET}
-        if job_config.source_format is None:
+        if new_job_config.source_format is None:
            # default value
-            job_config.source_format = job.SourceFormat.PARQUET
+            new_job_config.source_format = job.SourceFormat.PARQUET

if (
-            job_config.source_format == job.SourceFormat.PARQUET
-            and job_config.parquet_options is None
+            new_job_config.source_format == job.SourceFormat.PARQUET
+            and new_job_config.parquet_options is None
):
parquet_options = ParquetOptions()
# default value
parquet_options.enable_list_inference = True
-            job_config.parquet_options = parquet_options
+            new_job_config.parquet_options = parquet_options

-        if job_config.source_format not in supported_formats:
+        if new_job_config.source_format not in supported_formats:
raise ValueError(
"Got unexpected source_format: '{}'. Currently, only PARQUET and CSV are supported".format(
-                    job_config.source_format
+                    new_job_config.source_format
)
)

-        if pyarrow is None and job_config.source_format == job.SourceFormat.PARQUET:
+        if pyarrow is None and new_job_config.source_format == job.SourceFormat.PARQUET:
# pyarrow is now the only supported parquet engine.
raise ValueError("This method requires pyarrow to be installed")

@@ -2611,8 +2632,8 @@ def load_table_from_dataframe(
# schema, and check if dataframe schema is compatible with it - except
# for WRITE_TRUNCATE jobs, the existing schema does not matter then.
if (
-            not job_config.schema
-            and job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
+            not new_job_config.schema
+            and new_job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
):
try:
table = self.get_table(destination)
@@ -2623,7 +2644,7 @@ def load_table_from_dataframe(
name
for name, _ in _pandas_helpers.list_columns_and_indexes(dataframe)
)
-                job_config.schema = [
+                new_job_config.schema = [
# Field description and policy tags are not needed to
# serialize a data frame.
SchemaField(
@@ -2637,11 +2658,11 @@ def load_table_from_dataframe(
if field.name in columns_and_indexes
]

-        job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
-            dataframe, job_config.schema
+        new_job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
+            dataframe, new_job_config.schema
)

-        if not job_config.schema:
+        if not new_job_config.schema:
# the schema could not be fully detected
warnings.warn(
"Schema could not be detected for all columns. Loading from a "
@@ -2652,13 +2673,13 @@ def load_table_from_dataframe(
)

tmpfd, tmppath = tempfile.mkstemp(
suffix="_job_{}.{}".format(job_id[:8], job_config.source_format.lower())
suffix="_job_{}.{}".format(job_id[:8], new_job_config.source_format.lower())
)
os.close(tmpfd)

try:

-            if job_config.source_format == job.SourceFormat.PARQUET:
+            if new_job_config.source_format == job.SourceFormat.PARQUET:
if _PYARROW_VERSION in _PYARROW_BAD_VERSIONS:
msg = (
"Loading dataframe data in PARQUET format with pyarrow "
@@ -2669,13 +2690,13 @@ def load_table_from_dataframe(
)
warnings.warn(msg, category=RuntimeWarning)

-                if job_config.schema:
+                if new_job_config.schema:
if parquet_compression == "snappy": # adjust the default value
parquet_compression = parquet_compression.upper()

_pandas_helpers.dataframe_to_parquet(
dataframe,
-                        job_config.schema,
+                        new_job_config.schema,
tmppath,
parquet_compression=parquet_compression,
parquet_use_compliant_nested_type=True,
@@ -2715,7 +2736,7 @@ def load_table_from_dataframe(
job_id_prefix=job_id_prefix,
location=location,
project=project,
-                job_config=job_config,
+                job_config=new_job_config,
timeout=timeout,
)

@@ -2791,22 +2812,22 @@ def load_table_from_json(
Raises:
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

-        if job_config:
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-            # Make a copy so that the job config isn't modified in-place.
-            job_config = copy.deepcopy(job_config)
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()

-        job_config.source_format = job.SourceFormat.NEWLINE_DELIMITED_JSON
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        new_job_config.source_format = job.SourceFormat.NEWLINE_DELIMITED_JSON

-        if job_config.schema is None:
-            job_config.autodetect = True
+        if new_job_config.schema is None:
+            new_job_config.autodetect = True

if project is None:
project = self.project
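The format is always forced to ``NEWLINE_DELIMITED_JSON``, and autodetect turns on only when no schema survives the merge, so the simplest call is just rows plus a destination. A sketch with placeholder names:

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes application default credentials
    rows = [{"name": "a", "score": 1.0}, {"name": "b", "score": 2.0}]

    # No schema anywhere, so the merged config gets autodetect=True;
    # source_format is forced to NEWLINE_DELIMITED_JSON regardless.
    load_job = client.load_table_from_json(rows, "my-project.my_dataset.scores")
    load_job.result()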
@@ -2828,7 +2849,7 @@ def load_table_from_json(
job_id_prefix=job_id_prefix,
location=location,
project=project,
-                job_config=job_config,
+                job_config=new_job_config,
timeout=timeout,
)

6 changes: 5 additions & 1 deletion google/cloud/bigquery/job/base.py
@@ -269,7 +269,7 @@ def to_api_repr(self) -> dict:
"""
return copy.deepcopy(self._properties)

-    def _fill_from_default(self, default_job_config):
+    def _fill_from_default(self, default_job_config=None):
"""Merge this job config with a default job config.
The keys in this object take precedence over the keys in the default
@@ -283,6 +283,10 @@ def _fill_from_default(self, default_job_config):
Returns:
google.cloud.bigquery.job._JobConfig: A new (merged) job config.
"""
+        if not default_job_config:
+            new_job_config = copy.deepcopy(self)
+            return new_job_config
+
if self._job_type != default_job_config._job_type:
raise TypeError(
"attempted to merge two incompatible job types: "
8 changes: 4 additions & 4 deletions tests/system/test_client.py
@@ -2319,7 +2319,7 @@ def _table_exists(t):
return False


-def test_dbapi_create_view(dataset_id):
+def test_dbapi_create_view(dataset_id: str):

query = f"""
CREATE VIEW {dataset_id}.dbapi_create_view
@@ -2332,7 +2332,7 @@ def test_dbapi_create_view(dataset_id):
assert Config.CURSOR.rowcount == 0, "expected 0 rows"


-def test_parameterized_types_round_trip(dataset_id):
+def test_parameterized_types_round_trip(dataset_id: str):
client = Config.CLIENT
table_id = f"{dataset_id}.test_parameterized_types_round_trip"
fields = (
@@ -2358,7 +2358,7 @@ def test_parameterized_types_round_trip(dataset_id):
assert tuple(s._key()[:2] for s in table2.schema) == fields


-def test_table_snapshots(dataset_id):
+def test_table_snapshots(dataset_id: str):
from google.cloud.bigquery import CopyJobConfig
from google.cloud.bigquery import OperationType

@@ -2429,7 +2429,7 @@ def test_table_snapshots(dataset_id):
assert rows == [(1, "one"), (2, "two")]


-def test_table_clones(dataset_id):
+def test_table_clones(dataset_id: str):
from google.cloud.bigquery import CopyJobConfig
from google.cloud.bigquery import OperationType

29 changes: 28 additions & 1 deletion tests/unit/job/test_base.py
@@ -1104,7 +1104,7 @@ def test_ctor_with_unknown_property_raises_error(self):
config = self._make_one()
config.wrong_name = None

-    def test_fill_from_default(self):
+    def test_fill_query_job_config_from_default(self):
from google.cloud.bigquery import QueryJobConfig

job_config = QueryJobConfig()
@@ -1120,6 +1120,22 @@ def test_fill_from_default(self):
self.assertTrue(final_job_config.use_query_cache)
self.assertEqual(final_job_config.maximum_bytes_billed, 1000)

+    def test_fill_load_job_from_default(self):
+        from google.cloud.bigquery import LoadJobConfig
+
+        job_config = LoadJobConfig()
+        job_config.create_session = True
+        job_config.encoding = "UTF-8"
+
+        default_job_config = LoadJobConfig()
+        default_job_config.ignore_unknown_values = True
+        default_job_config.encoding = "ISO-8859-1"
+
+        final_job_config = job_config._fill_from_default(default_job_config)
+        self.assertTrue(final_job_config.create_session)
+        self.assertTrue(final_job_config.ignore_unknown_values)
+        self.assertEqual(final_job_config.encoding, "UTF-8")
+
def test_fill_from_default_conflict(self):
from google.cloud.bigquery import QueryJobConfig

@@ -1132,6 +1148,17 @@ def test_fill_from_default_conflict(self):
with self.assertRaises(TypeError):
basic_job_config._fill_from_default(conflicting_job_config)

+    def test_fill_from_empty_default_conflict(self):
+        from google.cloud.bigquery import QueryJobConfig
+
+        job_config = QueryJobConfig()
+        job_config.dry_run = True
+        job_config.maximum_bytes_billed = 1000
+
+        final_job_config = job_config._fill_from_default(default_job_config=None)
+        self.assertTrue(final_job_config.dry_run)
+        self.assertEqual(final_job_config.maximum_bytes_billed, 1000)
+
@mock.patch("google.cloud.bigquery._helpers._get_sub_prop")
def test__get_sub_prop_wo_default(self, _get_sub_prop):
job_config = self._make_one()
