respect soft_fail argument when exception is raised for google sensors (
Lee-W committed Sep 26, 2023
1 parent f99e65b commit 20b7cfc
Showing 25 changed files with 461 additions and 163 deletions.
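Every hunk below applies the same guard: before raising a hard error, the sensor now checks the soft_fail flag it inherits from BaseSensorOperator and raises AirflowSkipException instead, so the task is marked skipped rather than failed. The recurring TODO marks the guard as temporary: per its wording, once the provider's minimum Airflow version is above 2.7.1, core Airflow performs this mapping itself. A condensed sketch of the pattern, with a hypothetical sensor and status probe (not taken from the diff):

from __future__ import annotations

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


def check_resource() -> str:
    """Hypothetical status probe standing in for a Google hook call."""
    return "RUNNING"


class MyServiceSensor(BaseSensorOperator):
    """Hypothetical sensor; shows the guard each Google sensor gains."""

    def poke(self, context) -> bool:
        status = check_resource()
        if status in ("FAILED", "CANCELLED"):
            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
            message = f"Resource ended in terminal state: {status}"
            if self.soft_fail:
                raise AirflowSkipException(message)  # task ends up SKIPPED
            raise AirflowException(message)  # task ends up FAILED
        return status == "SUCCEEDED"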
23 changes: 20 additions & 3 deletions airflow/providers/google/cloud/sensors/bigquery.py
@@ -23,7 +23,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.triggers.bigquery import (
     BigQueryTableExistenceTrigger,
@@ -141,8 +141,16 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
         if event:
             if event["status"] == "success":
                 return event["message"]
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
-        raise AirflowException("No event received in trigger callback")
+
+        # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+        message = "No event received in trigger callback"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)


 class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
@@ -248,8 +256,17 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
         if event:
             if event["status"] == "success":
                 return event["message"]
+
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
-        raise AirflowException("No event received in trigger callback")
+
+        # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+        message = "No event received in trigger callback"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)


 class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
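For the deferrable sensors the guard sits in execute_complete, which runs on the worker after the trigger fires; previously an error event always raised AirflowException there, regardless of soft_fail. A hedged pytest sketch of the new contract (the commit's own tests may be structured differently):

import pytest

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor


def make_sensor(soft_fail: bool) -> BigQueryTableExistenceSensor:
    # Placeholder project/dataset/table IDs; execute_complete only inspects
    # the event, so BigQuery is never contacted here.
    return BigQueryTableExistenceSensor(
        task_id="check_table",
        project_id="test-project",
        dataset_id="test_dataset",
        table_id="test_table",
        soft_fail=soft_fail,
    )


def test_error_event_skips_when_soft_fail():
    with pytest.raises(AirflowSkipException):
        make_sensor(True).execute_complete(context={}, event={"status": "error", "message": "boom"})


def test_error_event_fails_by_default():
    with pytest.raises(AirflowException):
        make_sensor(False).execute_complete(context={}, event={"status": "error", "message": "boom"})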
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -23,7 +23,7 @@
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.bigquery_datatransfer_v1 import TransferState

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
 from airflow.sensors.base import BaseSensorOperator

@@ -140,5 +140,9 @@ def poke(self, context: Context) -> bool:
         self.log.info("Status of %s run: %s", self.run_id, str(run.state))

         if run.state in (TransferState.FAILED, TransferState.CANCELLED):
-            raise AirflowException(f"Transfer {self.run_id} did not succeed")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Transfer {self.run_id} did not succeed"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         return run.state in self.expected_statuses
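From the DAG author's side nothing changes except what soft_fail now means for these sensors. A usage sketch with placeholder connection and run IDs: with soft_fail=True, a FAILED or CANCELLED transfer run skips the task instead of failing it.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.bigquery_dts import (
    BigQueryDataTransferServiceTransferRunSensor,
)

with DAG("bq_dts_soft_fail_example", start_date=datetime(2023, 9, 1), schedule=None):
    wait_for_run = BigQueryDataTransferServiceTransferRunSensor(
        task_id="wait_for_transfer_run",
        transfer_config_id="my-transfer-config",  # placeholder
        run_id="my-run-id",  # placeholder
        project_id="my-project",  # placeholder
        expected_statuses={"SUCCEEDED"},
        soft_fail=True,  # a FAILED/CANCELLED run now skips this task
    )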
13 changes: 11 additions & 2 deletions airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -21,7 +21,7 @@

 from typing import TYPE_CHECKING, Any, Sequence

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
 from airflow.sensors.base import BaseSensorOperator

@@ -90,5 +90,14 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
         if event:
             if event.get("operation_done"):
                 return event["operation_done"]
+
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
-        raise AirflowException("No event received in trigger callback")
+
+        # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+        message = "No event received in trigger callback"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)
32 changes: 21 additions & 11 deletions airflow/providers/google/cloud/sensors/dataflow.py
@@ -20,7 +20,7 @@

 from typing import TYPE_CHECKING, Callable, Sequence

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.dataflow import (
     DEFAULT_DATAFLOW_LOCATION,
     DataflowHook,
@@ -106,7 +106,11 @@ def poke(self, context: Context) -> bool:
         if job_status in self.expected_statuses:
             return True
         elif job_status in DataflowJobStatus.TERMINAL_STATES:
-            raise AirflowException(f"Job with id '{self.job_id}' is already in terminal state: {job_status}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)

         return False

@@ -178,9 +182,11 @@ def poke(self, context: Context) -> bool:
         )
         job_status = job["currentState"]
         if job_status in DataflowJobStatus.TERMINAL_STATES:
-            raise AirflowException(
-                f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
-            )
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)

         result = self.hook.fetch_job_metrics_by_id(
             job_id=self.job_id,
@@ -257,9 +263,11 @@ def poke(self, context: Context) -> bool:
         )
         job_status = job["currentState"]
         if job_status in DataflowJobStatus.TERMINAL_STATES:
-            raise AirflowException(
-                f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
-            )
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)

         result = self.hook.fetch_job_messages_by_id(
             job_id=self.job_id,
@@ -336,9 +344,11 @@ def poke(self, context: Context) -> bool:
         )
         job_status = job["currentState"]
         if job_status in DataflowJobStatus.TERMINAL_STATES:
-            raise AirflowException(
-                f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
-            )
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)

         result = self.hook.fetch_job_autoscaling_events_by_id(
             job_id=self.job_id,
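The same five-line guard now appears in four poke methods in this file alone. The commit deliberately keeps it inline, which makes it easy to delete wholesale once the TODO is resolved; a shared helper along these lines (a sketch, not part of the commit) would be the alternative:

from __future__ import annotations

from airflow.exceptions import AirflowException, AirflowSkipException


def raise_or_skip(message: str, soft_fail: bool, exc: type[Exception] = AirflowException) -> None:
    """Raise AirflowSkipException when soft_fail is set, otherwise raise `exc`.

    Hypothetical helper; like the inline checks, it would become dead code
    once the provider requires Airflow > 2.7.1 and core maps failures to skips.
    """
    if soft_fail:
        raise AirflowSkipException(message)
    raise exc(message)

Each call site would then collapse to a single raise_or_skip(message, self.soft_fail) call.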
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/sensors/dataform.py
@@ -20,7 +20,7 @@

 from typing import TYPE_CHECKING, Iterable, Sequence

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.dataform import DataformHook
 from airflow.sensors.base import BaseSensorOperator

@@ -95,9 +95,13 @@ def poke(self, context: Context) -> bool:
         workflow_status = workflow_invocation.state
         if workflow_status is not None:
             if self.failure_statuses and workflow_status in self.failure_statuses:
-                raise AirflowException(
+                # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+                message = (
                     f"Workflow Invocation with id '{self.workflow_invocation_id}' "
                     f"state is: {workflow_status}. Terminating sensor..."
                 )
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)

         return workflow_status in self.expected_statuses
14 changes: 11 additions & 3 deletions airflow/providers/google/cloud/sensors/datafusion.py
@@ -20,7 +20,7 @@

 from typing import TYPE_CHECKING, Iterable, Sequence

-from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
 from airflow.sensors.base import BaseSensorOperator

@@ -109,15 +109,23 @@ def poke(self, context: Context) -> bool:
             )
             pipeline_status = pipeline_workflow["status"]
         except AirflowNotFoundException:
-            raise AirflowException("Specified Pipeline ID was not found.")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = "Specified Pipeline ID was not found."
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         except AirflowException:
             pass  # Because the pipeline may not be visible in system yet
         if pipeline_status is not None:
             if self.failure_statuses and pipeline_status in self.failure_statuses:
-                raise AirflowException(
+                # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+                message = (
                     f"Pipeline with id '{self.pipeline_id}' state is: {pipeline_status}. "
                     f"Terminating sensor..."
                 )
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)

         self.log.debug(
             "Current status of the pipeline workflow for %s: %s.", self.pipeline_id, pipeline_status
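The handler order above matters: AirflowNotFoundException subclasses AirflowException, so the broad except must come second, and the new skip raised inside the first handler propagates instead of being swallowed by the pass branch next to it. A small self-contained check of those semantics (illustrative, not from the commit):

from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException

assert issubclass(AirflowNotFoundException, AirflowException)
assert issubclass(AirflowSkipException, AirflowException)


def lookup():
    # Stand-in for the hook call raising "not found".
    raise AirflowNotFoundException("Specified Pipeline ID was not found.")


try:
    try:
        lookup()
    except AirflowNotFoundException:
        # Raising inside a handler propagates: sibling except clauses never see it.
        raise AirflowSkipException("Specified Pipeline ID was not found.")
    except AirflowException:
        pass  # Because the pipeline may not be visible in system yet
except AirflowSkipException as err:
    print(f"task would be skipped: {err}")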
41 changes: 32 additions & 9 deletions airflow/providers/google/cloud/sensors/dataplex.py
@@ -24,11 +24,12 @@
     from google.api_core.retry import Retry

     from airflow.utils.context import Context

 from google.api_core.exceptions import GoogleAPICallError
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.dataplex_v1.types import DataScanJob

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.dataplex import (
     AirflowDataQualityScanException,
     AirflowDataQualityScanResultTimeoutException,
@@ -116,7 +117,11 @@ def poke(self, context: Context) -> bool:
         task_status = task.state

         if task_status == TaskState.DELETING:
-            raise AirflowException(f"Task is going to be deleted {self.dataplex_task_id}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Task is going to be deleted {self.dataplex_task_id}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)

         self.log.info("Current status of the Dataplex task %s => %s", self.dataplex_task_id, task_status)

@@ -196,9 +201,13 @@ def poke(self, context: Context) -> bool:
         if self.result_timeout:
             duration = self._duration()
             if duration > self.result_timeout:
-                raise AirflowDataQualityScanResultTimeoutException(
+                # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+                message = (
                     f"Timeout: Data Quality scan {self.job_id} is not ready after {self.result_timeout}s"
                 )
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowDataQualityScanResultTimeoutException(message)

         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -217,22 +226,36 @@
                 metadata=self.metadata,
             )
         except GoogleAPICallError as e:
-            raise AirflowException(
-                f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}", e
-            )
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}"
+            if self.soft_fail:
+                raise AirflowSkipException(message, e)
+            raise AirflowException(message, e)

         job_status = job.state
         self.log.info(
             "Current status of the Dataplex Data Quality scan job %s => %s", self.job_id, job_status
         )
         if job_status == DataScanJob.State.FAILED:
-            raise AirflowException(f"Data Quality scan job failed: {self.job_id}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Data Quality scan job failed: {self.job_id}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         if job_status == DataScanJob.State.CANCELLED:
-            raise AirflowException(f"Data Quality scan job cancelled: {self.job_id}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Data Quality scan job cancelled: {self.job_id}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         if self.fail_on_dq_failure:
             if job_status == DataScanJob.State.SUCCEEDED and not job.data_quality_result.passed:
-                raise AirflowDataQualityScanException(
+                # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+                message = (
                     f"Data Quality job {self.job_id} execution failed due to failure of its scanning "
                     f"rules: {self.data_scan_id}"
                 )
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowDataQualityScanException(message)
         return job_status == DataScanJob.State.SUCCEEDED
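One behavioral subtlety in this file: with soft_fail=True the skip now takes precedence over the provider's domain-specific exceptions (AirflowDataQualityScanResultTimeoutException, AirflowDataQualityScanException), while the GoogleAPICallError branch keeps the underlying error attached as a second argument in both branches. Code that previously caught the scan-specific types will not see them on soft-fail sensors, as this hypothetical caller sketch shows:

from airflow.exceptions import AirflowSkipException
from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanException


def poke_with_soft_fail() -> bool:
    # Stand-in for the Dataplex data-quality sensor's poke() on a sensor
    # built with soft_fail=True whose scan finished with failing rules.
    raise AirflowSkipException(
        "Data Quality job dq-job-1 execution failed due to failure of its scanning rules: scan-1"
    )


try:
    poke_with_soft_fail()
except AirflowDataQualityScanException:
    print("handled a data-quality failure")  # no longer reached when soft_fail=True
except AirflowSkipException as err:
    print(f"task skipped instead: {err}")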
32 changes: 26 additions & 6 deletions airflow/providers/google/cloud/sensors/dataproc.py
@@ -24,7 +24,7 @@

 from google.api_core.exceptions import ServerError
 from google.cloud.dataproc_v1.types import Batch, JobStatus

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
 from airflow.sensors.base import BaseSensorOperator

@@ -83,24 +83,36 @@ def poke(self, context: Context) -> bool:
                 duration = self._duration()
                 self.log.info("DURATION RUN: %f", duration)
                 if duration > self.wait_timeout:
-                    raise AirflowException(
+                    # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+                    message = (
                         f"Timeout: dataproc job {self.dataproc_job_id} "
                         f"is not ready after {self.wait_timeout}s"
                     )
+                    if self.soft_fail:
+                        raise AirflowSkipException(message)
+                    raise AirflowException(message)
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
                 return False
         else:
             job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)

         state = job.status.state
         if state == JobStatus.State.ERROR:
-            raise AirflowException(f"Job failed:\n{job}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job failed:\n{job}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         elif state in {
             JobStatus.State.CANCELLED,
             JobStatus.State.CANCEL_PENDING,
             JobStatus.State.CANCEL_STARTED,
         }:
-            raise AirflowException(f"Job was cancelled:\n{job}")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = f"Job was cancelled:\n{job}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         elif JobStatus.State.DONE == state:
             self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
             return True
@@ -171,12 +183,20 @@ def poke(self, context: Context) -> bool:

         state = batch.state
         if state == Batch.State.FAILED:
-            raise AirflowException("Batch failed")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = "Batch failed"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         elif state in {
             Batch.State.CANCELLED,
             Batch.State.CANCELLING,
         }:
-            raise AirflowException("Batch was cancelled.")
+            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+            message = "Batch was cancelled."
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         elif state == Batch.State.SUCCEEDED:
             self.log.debug("Batch %s completed successfully.", self.batch_id)
             return True
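A hedged test sketch for the Dataproc change, with the hook mocked out so poke() sees a job in ERROR state (the real PR's tests may be structured differently):

from unittest import mock

import pytest
from google.cloud.dataproc_v1.types import JobStatus

from airflow.exceptions import AirflowSkipException
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor


@mock.patch("airflow.providers.google.cloud.sensors.dataproc.DataprocHook")
def test_error_state_with_soft_fail_skips(mock_hook):
    # Simulate a Dataproc job that ended in ERROR.
    job = mock.MagicMock()
    job.status.state = JobStatus.State.ERROR
    mock_hook.return_value.get_job.return_value = job

    sensor = DataprocJobSensor(
        task_id="wait_job",
        dataproc_job_id="job-123",  # placeholder
        region="us-central1",  # placeholder
        project_id="test-project",  # placeholder
        soft_fail=True,
    )
    # With soft_fail=True the ERROR state now skips instead of failing.
    with pytest.raises(AirflowSkipException):
        sensor.poke(context={})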
