Skip to content

Commit

Permalink
Support query timeout as an argument in CassandraToGCSOperator (#18927)
Browse files Browse the repository at this point in the history
Support query timeout as an argument in CassandraToGCSOperator (#18927)
  • Loading branch information
xuan616 committed Oct 28, 2021
1 parent fd569e7 commit 55abc2f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
18 changes: 16 additions & 2 deletions airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from datetime import datetime
from decimal import Decimal
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, NewType, Optional, Sequence, Tuple, Union
from uuid import UUID

from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time
Expand All @@ -36,6 +36,9 @@
from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook

# Sentinel used as the default for ``query_timeout`` so that "argument not supplied"
# can be told apart from an explicit ``None``; compare against it with ``is``.
NotSetType = NewType('NotSetType', object)
NOT_SET = NotSetType(object())


class CassandraToGCSOperator(BaseOperator):
"""
Expand Down Expand Up @@ -84,6 +87,10 @@ class CassandraToGCSOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param query_timeout: (Optional) The maximum time, in seconds, to wait for the Cassandra query to complete.
If not set, no timeout is passed and the Cassandra driver's default for Session.execute() applies.
If set to None, there is no timeout.
:type query_timeout: float | None
"""

template_fields = (
Expand All @@ -110,6 +117,7 @@ def __init__(
google_cloud_storage_conn_id: Optional[str] = None,
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
query_timeout: Union[float, None, NotSetType] = NOT_SET,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -133,6 +141,7 @@ def __init__(
self.delegate_to = delegate_to
self.gzip = gzip
self.impersonation_chain = impersonation_chain
self.query_timeout = query_timeout

# Default Cassandra to BigQuery type mapping
CQL_TYPE_MAP = {
Expand Down Expand Up @@ -162,7 +171,12 @@ def __init__(

def execute(self, context: Dict[str, str]):
hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
cursor = hook.get_conn().execute(self.cql)

query_extra = {}
if self.query_timeout is not NOT_SET:
query_extra['timeout'] = self.query_timeout

cursor = hook.get_conn().execute(self.cql, **query_extra)

files_to_upload = self._write_local_data_files(cursor)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile):
schema = "schema.json"
filename = "data.json"
gzip = True
query_timeout = 20
mock_tempfile.return_value.name = TMP_FILE_NAME

operator = CassandraToGCSOperator(
Expand All @@ -43,9 +44,14 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile):
filename=filename,
schema_filename=schema,
gzip=gzip,
query_timeout=query_timeout,
)
operator.execute(None)
mock_hook.return_value.get_conn.assert_called_once_with()
mock_hook.return_value.get_conn.return_value.execute.assert_called_once_with(
"select * from keyspace1.table1",
timeout=20,
)

call_schema = call(
bucket_name=test_bucket,
Expand Down

0 comments on commit 55abc2f

Please sign in to comment.