Refactor regex in providers (#33898)
* Refactor regex in providers

* Satisfy Mypy's optional check

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
eumiro and uranusjr committed Sep 5, 2023
1 parent bb5e186 commit a7310f9
Showing 17 changed files with 45 additions and 44 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/chime.py
@@ -70,7 +70,7 @@ def _get_webhook_endpoint(self, conn_id: str) -> str:
url = conn.schema + "://" + conn.host
endpoint = url + token
# Check to make sure the endpoint matches what Chime expects
if not re.match(r"^[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+$", token):
if not re.fullmatch(r"[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+", token):
raise AirflowException(
"Expected Chime webhook token in the form of '{webhook.id}?token={webhook.token}'."
)
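For reference, a minimal standalone check of the equivalence this hunk relies on; the sample token below is made up, not taken from the commit:

    import re

    # re.match(r"^pattern$", s) and re.fullmatch(r"pattern", s) accept the same tokens,
    # but fullmatch states "match the whole string" without manual anchors.
    # One small difference: "$" also matches just before a trailing newline.
    token = "abc-123?token=XyZ_456"
    anchored = re.match(r"^[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+$", token)
    full = re.fullmatch(r"[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+", token)
    assert (anchored is None) == (full is None)
    assert re.match(r"^abc$", "abc\n") and not re.fullmatch(r"abc", "abc\n")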
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -468,7 +468,7 @@ async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str,
:param bucket_name: the name of the bucket
:param key: the path to the key
"""
prefix = re.split(r"[\[\*\?]", key, 1)[0]
prefix = re.split(r"[\[*?]", key, 1)[0]
delimiter = ""
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
@@ -570,7 +570,7 @@ async def get_files_async(
for key in bucket_keys:
prefix = key
if wildcard_match:
prefix = re.split(r"[\[\*\?]", key, 1)[0]
prefix = re.split(r"[\[*?]", key, 1)[0]

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
@@ -1015,7 +1015,7 @@ def get_wildcard_key(
:param delimiter: the delimiter marks key hierarchy
:return: the key object from the bucket or None if none has been found.
"""
prefix = re.split(r"[\[\*\?]", wildcard_key, 1)[0]
prefix = re.split(r"[\[*?]", wildcard_key, 1)[0]
key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)]
if key_matches:
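A small sketch of why the escapes can be dropped here; the key below is an illustrative S3 key, not one from the commit:

    import re

    # Inside a character class, "*" and "?" are already literal, so "[\[*?]" and
    # "[\[\*\?]" describe the same set; splitting on the first wildcard yields the
    # literal prefix used for listing.
    key = "data/2023/*/events-?.json"
    prefix = re.split(r"[\[*?]", key, 1)[0]
    assert prefix == "data/2023/"
    assert re.split(r"[\[\*\?]", key, 1)[0] == prefix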
3 changes: 1 addition & 2 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -979,8 +979,7 @@ def _name_matches_pattern(
found_name: str,
job_name_suffix: str | None = None,
) -> bool:
pattern = re.compile(f"^{processing_job_name}({job_name_suffix})?$")
return pattern.fullmatch(found_name) is not None
return re.fullmatch(f"{processing_job_name}({job_name_suffix})?", found_name) is not None

def count_processing_jobs_by_name(
self,
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/s3.py
@@ -113,7 +113,7 @@ def _check_key(self, key):
}]
"""
if self.wildcard_match:
prefix = re.split(r"[\[\*\?]", key, 1)[0]
prefix = re.split(r"[\[*?]", key, 1)[0]
keys = self.hook.get_file_metadata(prefix, bucket_name)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
if not key_matches:
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/utils/__init__.py
@@ -66,8 +66,10 @@ def datetime_to_epoch_us(date_time: datetime) -> int:


def get_airflow_version() -> tuple[int, ...]:
val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version)
return tuple(int(x) for x in val.split("."))
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
if match is None: # Not theoretically possible.
raise RuntimeError(f"Broken Airflow version: {version}")
return tuple(int(x) for x in match.groups())


class _StringCompareEnum(Enum):
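A standalone sketch comparing the old and new parsing on an assumed development version string:

    import re

    # Both forms extract the leading "major.minor.patch" triple; the rewrite makes
    # the failure mode explicit instead of letting int() raise on an unparsed string.
    version = "2.7.1.dev0"  # illustrative value
    old = tuple(int(x) for x in re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version).split("."))
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    assert match is not None and tuple(int(x) for x in match.groups()) == old == (2, 7, 1)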
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/hooks/hive.py
@@ -315,7 +315,7 @@ def test_hql(self, hql: str) -> None:
message = e.args[0].splitlines()[-2]
self.log.info(message)
error_loc = re.search(r"(\d+):(\d+)", message)
if error_loc and error_loc.group(1).isdigit():
if error_loc:
lst = int(error_loc.group(1))
begin = max(lst - 2, 0)
end = min(lst + 3, len(query.splitlines()))
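A short illustration, with a made-up Hive error message, of why the dropped isdigit() guard was redundant:

    import re

    # group(1) of r"(\d+):(\d+)" can only ever contain digits, so the extra
    # isdigit() check never changed the outcome.
    error_loc = re.search(r"(\d+):(\d+)", "Error while compiling statement: line 4:17 ...")
    assert error_loc and error_loc.group(1) == "4" and error_loc.group(1).isdigit()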
4 changes: 2 additions & 2 deletions airflow/providers/apache/livy/hooks/livy.py
@@ -419,7 +419,7 @@ def _validate_size_format(size: str) -> bool:
:param size: size value
:return: true if valid format
"""
if size and not (isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE)):
if size and not (isinstance(size, str) and re.fullmatch(r"\d+[kmgt]b?", size, re.IGNORECASE)):
raise ValueError(f"Invalid java size format for string'{size}'")
return True

@@ -800,7 +800,7 @@ def _validate_size_format(size: str) -> bool:
:param size: size value
:return: true if valid format
"""
if size and not (isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE)):
if size and not (isinstance(size, str) and re.fullmatch(r"\d+[kmgt]b?", size, re.IGNORECASE)):
raise ValueError(f"Invalid java size format for string'{size}'")
return True

12 changes: 6 additions & 6 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -461,31 +461,31 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
# If we run yarn cluster mode, we want to extract the application id from
# the logs so we can kill the application when we stop it unexpectedly
if self._is_yarn and self._connection["deploy_mode"] == "cluster":
match = re.search("(application[0-9_]+)", line)
match = re.search("application[0-9_]+", line)
if match:
self._yarn_application_id = match.groups()[0]
self._yarn_application_id = match.group(0)
self.log.info("Identified spark driver id: %s", self._yarn_application_id)

# If we run Kubernetes cluster mode, we want to extract the driver pod id
# from the logs so we can kill the application when we stop it unexpectedly
elif self._is_kubernetes:
match = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver)", line)
if match:
self._kubernetes_driver_pod = match.groups()[0]
self._kubernetes_driver_pod = match.group(1)
self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod)

# Store the Spark Exit code
match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", line)
if match_exit_code:
self._spark_exit_code = int(match_exit_code.groups()[0])
self._spark_exit_code = int(match_exit_code.group(1))

# if we run in standalone cluster mode and we want to track the driver status
# we need to extract the driver id from the logs. This allows us to poll for
# the status using the driver id. Also, we can kill the driver when needed.
elif self._should_track_driver_status and not self._driver_id:
match_driver_id = re.search(r"(driver-[0-9\-]+)", line)
match_driver_id = re.search(r"driver-[0-9\-]+", line)
if match_driver_id:
self._driver_id = match_driver_id.groups()[0]
self._driver_id = match_driver_id.group(0)
self.log.info("identified spark driver id: %s", self._driver_id)

self.log.info(line)
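A brief sketch, on invented log lines, of why group(0) and group(1) are drop-in replacements for groups()[0] here:

    import re

    # With no capturing group, the whole match is group(0); with exactly one group,
    # match.group(1) and match.groups()[0] are the same value, so the rewrite only
    # drops the now-unneeded parentheses.
    line = "... submitted application_1693900000000_0042 to ResourceManager"  # illustrative
    match = re.search("application[0-9_]+", line)
    assert match and match.group(0) == "application_1693900000000_0042"

    exit_line = "Exit code: 1"  # illustrative
    match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", exit_line)
    assert match_exit_code and match_exit_code.group(1) == match_exit_code.groups()[0] == "1"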
2 changes: 1 addition & 1 deletion airflow/providers/common/sql/operators/sql.py
@@ -84,7 +84,7 @@ def _get_failed_checks(checks, col=None):
"""


_PROVIDERS_MATCHER = re.compile(r"airflow\.providers\.(.*)\.hooks.*")
_PROVIDERS_MATCHER = re.compile(r"airflow\.providers\.(.*?)\.hooks.*")

_MIN_SUPPORTED_PROVIDERS_VERSION = {
"amazon": "4.1.0",
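A standalone comparison of the greedy and non-greedy patterns; the second path is contrived purely to show where they differ:

    import re

    # For typical hook module paths the two forms agree; they only diverge when
    # ".hooks" appears twice, where ".*?" stops at the first occurrence.
    greedy = re.compile(r"airflow\.providers\.(.*)\.hooks.*")
    lazy = re.compile(r"airflow\.providers\.(.*?)\.hooks.*")
    path = "airflow.providers.common.sql.hooks.sql"
    assert greedy.match(path).group(1) == lazy.match(path).group(1) == "common.sql"
    odd = "airflow.providers.foo.hooks.bar.hooks.baz"  # contrived path
    assert greedy.match(odd).group(1) == "foo.hooks.bar"
    assert lazy.match(odd).group(1) == "foo"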
2 changes: 1 addition & 1 deletion airflow/providers/discord/hooks/discord_webhook.py
@@ -95,7 +95,7 @@ def _get_webhook_endpoint(self, http_conn_id: str | None, webhook_endpoint: str
)

# make sure endpoint matches the expected Discord webhook format
if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint):
if not re.fullmatch("webhooks/[0-9]+/[a-zA-Z0-9_-]+", endpoint):
raise AirflowException(
'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".'
)
5 changes: 3 additions & 2 deletions airflow/providers/ftp/sensors/ftp.py
@@ -44,7 +44,7 @@ class FTPSensor(BaseSensorOperator):
"""Errors that are transient in nature, and where action can be retried"""
transient_errors = [421, 425, 426, 434, 450, 451, 452]

error_code_pattern = re.compile(r"([\d]+)")
error_code_pattern = re.compile(r"\d+")

def __init__(
self, *, path: str, ftp_conn_id: str = "ftp_default", fail_on_transient_errors: bool = True, **kwargs
@@ -64,9 +64,10 @@ def _get_error_code(self, e):
try:
matches = self.error_code_pattern.match(str(e))
code = int(matches.group(0))
return code
except ValueError:
return e
else:
return code

def poke(self, context: Context) -> bool:
with self._create_hook() as hook:
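A short sketch, using a made-up FTP error string, of why the capturing group was unnecessary:

    import re

    # "([\d]+)" and r"\d+" match the same text; without a capturing group the
    # matched digits are still available as group(0), so the capture was redundant.
    error_code_pattern = re.compile(r"\d+")
    matches = error_code_pattern.match("530 Login incorrect.")  # illustrative error text
    assert matches and int(matches.group(0)) == 530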
3 changes: 1 addition & 2 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -670,8 +670,7 @@ def get_proxy_version(self) -> str | None:
command_to_run.extend(["--version"])
command_to_run.extend(self._get_credential_parameters())
result = subprocess.check_output(command_to_run).decode("utf-8")
pattern = re.compile("^.*[V|v]ersion ([^;]*);.*$")
matched = pattern.match(result)
matched = re.search("[Vv]ersion (.*?);", result)
if matched:
return matched.group(1)
else:
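A small sketch on an assumed proxy output line; note that the old "[V|v]" was a character class containing "|", not an alternation:

    import re

    # re.search with a lazy capture reads the version without anchoring on the
    # whole line, and [Vv] expresses the case-insensitive first letter directly.
    output = "Cloud SQL Auth proxy version 1.33.2; latest ..."  # illustrative output
    matched = re.search("[Vv]ersion (.*?);", output)
    assert matched and matched.group(1) == "1.33.2"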
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/dataflow.py
@@ -867,7 +867,7 @@ def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
"""Builds Dataflow job name."""
base_job_name = str(job_name).replace("_", "-")

if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name):
if not re.fullmatch(r"[a-z]([-a-z0-9]*[a-z0-9])?", base_job_name):
raise ValueError(
f"Invalid job_name ({base_job_name}); the name must consist of only the characters "
f"[-a-z0-9], starting with a letter and ending with a letter or number "
28 changes: 14 additions & 14 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -245,12 +245,13 @@ def _set_preemptibility_type(self, preemptibility: str):
return PreemptibilityType(preemptibility.upper())

def _get_init_action_timeout(self) -> dict:
match = re.match(r"^(\d+)([sm])$", self.init_action_timeout)
match = re.fullmatch(r"(\d+)([sm])", self.init_action_timeout)
if match:
val = float(match.group(1))
if match.group(2) == "s":
return {"seconds": int(val)}
elif match.group(2) == "m":
val = int(match.group(1))
unit = match.group(2)
if unit == "s":
return {"seconds": val}
elif unit == "m":
return {"seconds": int(timedelta(minutes=val).total_seconds())}

raise AirflowException(
@@ -811,18 +812,17 @@ def _graceful_decommission_timeout_object(self) -> dict[str, int] | None:
return None

timeout = None
match = re.match(r"^(\d+)([smdh])$", self.graceful_decommission_timeout)
match = re.fullmatch(r"(\d+)([smdh])", self.graceful_decommission_timeout)
if match:
if match.group(2) == "s":
timeout = int(match.group(1))
elif match.group(2) == "m":
val = float(match.group(1))
val = int(match.group(1))
unit = match.group(2)
if unit == "s":
timeout = val
elif unit == "m":
timeout = int(timedelta(minutes=val).total_seconds())
elif match.group(2) == "h":
val = float(match.group(1))
elif unit == "h":
timeout = int(timedelta(hours=val).total_seconds())
elif match.group(2) == "d":
val = float(match.group(1))
elif unit == "d":
timeout = int(timedelta(days=val).total_seconds())

if not timeout:
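A standalone sketch of the parsing contract, wrapped in a hypothetical helper (the function name is illustrative, not from the provider):

    import re
    from datetime import timedelta

    # fullmatch keeps the "digits plus single unit letter" contract, and parsing the
    # digits with int() directly avoids the float round-trip the old code did.
    def graceful_timeout_seconds(raw: str) -> int | None:
        match = re.fullmatch(r"(\d+)([smdh])", raw)
        if not match:
            return None
        val, unit = int(match.group(1)), match.group(2)
        scale = {"s": timedelta(seconds=1), "m": timedelta(minutes=1),
                 "h": timedelta(hours=1), "d": timedelta(days=1)}[unit]
        return int((val * scale).total_seconds())

    assert graceful_timeout_seconds("90s") == 90
    assert graceful_timeout_seconds("2h") == 7200
    assert graceful_timeout_seconds("1.5h") is None  # fullmatch rejects fractions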
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/secrets/secret_manager.py
@@ -36,8 +36,8 @@


def _parse_version(val):
val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), val)
return tuple(int(x) for x in val.split("."))
match = re.search(r"(\d+)\.(\d+)\.(\d+)", val)
return tuple(int(x) for x in match.groups())


class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin):
@@ -193,7 +193,7 @@ def validate_err_and_count(summary):

# Verify that task_prefix doesn't have any special characters except hyphen
# '-', which is the only allowed non-alphanumeric character by Dataflow.
if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
if not re.fullmatch(r"[a-zA-Z][-A-Za-z0-9]*", task_prefix):
raise AirflowException(
"Malformed task_id for DataFlowPythonOperator (only alphanumeric "
"and hyphens are allowed but got: " + task_prefix
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/secrets/key_vault.py
@@ -33,8 +33,8 @@


def _parse_version(val):
val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), val)
return tuple(int(x) for x in val.split("."))
match = re.search(r"(\d+)\.(\d+)\.(\d+)", val)
return tuple(int(x) for x in match.groups())


class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
