Skip to content

Commit

Permalink
Check that cloud sql provider version is valid (#29497)
Browse files Browse the repository at this point in the history
An additional check on the cloud sql version should be done to avoid
downloading a non-existent binary.
  • Loading branch information
potiuk committed Feb 13, 2023
1 parent cf81455 commit 5e6f8eb
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
31 changes: 21 additions & 10 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Expand Up @@ -59,6 +59,8 @@
# Time to sleep between active checks of the operation results
TIME_TO_SLEEP_IN_SECONDS = 20

# Accepted shape of a proxy version string: an optional leading "v", three
# dot-separated numeric components, then an optional "-<suffix>" part
# (e.g. "v1.23.0" or "v1.23.0-preview.1").
# NOTE(review): the "." inside the suffix group is unescaped, so it matches any
# single character rather than only a literal dot, and "$" also matches right
# before a trailing newline - confirm whether either is intended.
CLOUD_SQL_PROXY_VERSION_REGEX = re.compile(r"^v?(\d+\.\d+\.\d+)(-\w*.?\d?)?$")


class CloudSqlOperationStatus:
"""Helper class with operation statuses."""
Expand Down Expand Up @@ -449,16 +451,7 @@ def _download_sql_proxy_if_needed(self) -> None:
if os.path.isfile(self.sql_proxy_path):
self.log.info("cloud-sql-proxy is already present")
return
system = platform.system().lower()
processor = os.uname().machine
if processor == "x86_64":
processor = "amd64"
if not self.sql_proxy_version:
download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)
else:
download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
self.sql_proxy_version, system, processor
)
download_url = self._get_sql_proxy_download_url()
proxy_path_tmp = self.sql_proxy_path + ".tmp"
self.log.info("Downloading cloud_sql_proxy from %s to %s", download_url, proxy_path_tmp)
# httpx has a breaking API change (follow_redirects vs allow_redirects)
Expand All @@ -482,6 +475,24 @@ def _download_sql_proxy_if_needed(self) -> None:
os.chmod(self.sql_proxy_path, 0o744) # Set executable bit
self.sql_proxy_was_downloaded = True

def _get_sql_proxy_download_url(self) -> str:
    """
    Build the URL the cloud-sql-proxy binary is downloaded from.

    The URL is formatted from the lower-cased operating system name, the
    machine architecture (``x86_64`` is reported as ``amd64``) and, when it is
    set, ``self.sql_proxy_version``.

    :return: the download URL for the current platform.
    :raises ValueError: if ``self.sql_proxy_version`` is set but does not match
        ``CLOUD_SQL_PROXY_VERSION_REGEX``.
    """
    system = platform.system().lower()
    processor = os.uname().machine
    if processor == "x86_64":
        processor = "amd64"

    # No explicit version requested: use the un-versioned download URL.
    if not self.sql_proxy_version:
        return CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)

    # ``fullmatch`` is used instead of ``match`` because the pattern ends with
    # "$", which also matches just before a trailing newline. With ``match`` a
    # value such as "v1.23.0\n" would pass validation and the newline would end
    # up inside the download URL.
    if not CLOUD_SQL_PROXY_VERSION_REGEX.fullmatch(self.sql_proxy_version):
        raise ValueError(
            "The sql_proxy_version should match the regular expression "
            f"{CLOUD_SQL_PROXY_VERSION_REGEX.pattern}"
        )
    return CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(self.sql_proxy_version, system, processor)

def _get_credential_parameters(self) -> list[str]:
extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson
key_path = get_field(extras, "key_path")
Expand Down
75 changes: 70 additions & 5 deletions tests/providers/google/cloud/hooks/test_cloud_sql.py
Expand Up @@ -18,6 +18,9 @@
from __future__ import annotations

import json
import os
import platform
import tempfile
from unittest import mock
from unittest.mock import PropertyMock

Expand All @@ -27,7 +30,11 @@

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.hooks.cloud_sql import (
CloudSQLDatabaseHook,
CloudSQLHook,
CloudSqlProxyRunner,
)
from tests.providers.google.cloud.utils.base_gcp_mock import (
mock_base_gcp_hook_default_project_id,
mock_base_gcp_hook_no_default_project_id,
Expand Down Expand Up @@ -847,8 +854,12 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
err = ctx.value
assert "must be a readable file" in str(err)

@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_connection):
def test_cloudsql_database_hook_validate_socket_path_length_too_long(
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
connection = Connection()
connection.set_extra(
json.dumps(
Expand All @@ -870,8 +881,12 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
err = ctx.value
assert "The UNIX socket path length cannot exceed" in str(err)

@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(self, get_connection):
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
connection = Connection()
connection.set_extra(
json.dumps(
Expand Down Expand Up @@ -1093,7 +1108,7 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert "postgres" == connection.conn_type
assert "/tmp" in connection.host
assert tempfile.gettempdir() in connection.host
assert "example-project:europe-west1:testdb" in connection.host
assert connection.port is None
assert "testdb" == connection.schema
Expand Down Expand Up @@ -1166,7 +1181,7 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
connection = hook.create_connection()
assert "mysql" == connection.conn_type
assert "localhost" == connection.host
assert "/tmp" in connection.extra_dejson["unix_socket"]
assert tempfile.gettempdir() in connection.extra_dejson["unix_socket"]
assert "example-project:europe-west1:testdb" in connection.extra_dejson["unix_socket"]
assert connection.port is None
assert "testdb" == connection.schema
Expand All @@ -1185,3 +1200,53 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
assert "127.0.0.1" == connection.host
assert 3200 != connection.port
assert "testdb" == connection.schema


def get_processor():
    """Return the machine architecture name, reporting ``x86_64`` as ``amd64``."""
    machine = os.uname().machine
    return "amd64" if machine == "x86_64" else machine


class TestCloudSqlProxyRunner:
    """Tests for the download URL that ``CloudSqlProxyRunner`` builds from ``sql_proxy_version``."""

    @pytest.mark.parametrize(
        ["version", "download_url"],
        [
            (
                "v1.23.0",
                "https://storage.googleapis.com/cloudsql-proxy/v1.23.0/cloud_sql_proxy."
                f"{platform.system().lower()}.{get_processor()}",
            ),
            (
                "v1.23.0-preview.1",
                "https://storage.googleapis.com/cloudsql-proxy/v1.23.0-preview.1/cloud_sql_proxy."
                f"{platform.system().lower()}.{get_processor()}",
            ),
        ],
    )
    def test_cloud_sql_proxy_runner_version_ok(self, version, download_url):
        """A well-formed version is placed into the versioned download URL."""
        proxy_runner = CloudSqlProxyRunner(
            path_prefix="12345678",
            instance_specification="project:us-east-1:instance",
            sql_proxy_version=version,
        )
        actual_url = proxy_runner._get_sql_proxy_download_url()
        assert actual_url == download_url

    @pytest.mark.parametrize(
        "version",
        [
            "v1.23.",
            "v1.23.0..",
            "v1.23.0\\",
            "\\",
        ],
    )
    def test_cloud_sql_proxy_runner_version_nok(self, version):
        """A malformed version is rejected with ``ValueError`` when the URL is built."""
        proxy_runner = CloudSqlProxyRunner(
            path_prefix="12345678",
            instance_specification="project:us-east-1:instance",
            sql_proxy_version=version,
        )
        expected_message = "The sql_proxy_version should match the regular expression"
        with pytest.raises(ValueError, match=expected_message):
            proxy_runner._get_sql_proxy_download_url()

0 comments on commit 5e6f8eb

Please sign in to comment.