Skip to content

Commit

Permalink
Move cloud_sql_binary_path from connection to Hook (#29499)
Browse files Browse the repository at this point in the history
Specifying cloud_sql_binary_path in connection should never be needed,
This is at most property of the Hook (and Operator by transition) if you
want to override it, rather than extra in the connection.
  • Loading branch information
potiuk committed Feb 13, 2023
1 parent 8e24387 commit 32c571e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def get_absolute_path(path):
"location={location}&"
"instance={instance}&"
"use_proxy=True&"
"sql_proxy_binary_path={sql_proxy_binary_path}&"
"sql_proxy_use_tcp=False".format(sql_proxy_binary_path=quote_plus(sql_proxy_binary_path), **mysql_kwargs)
"sql_proxy_use_tcp=False".format(**mysql_kwargs)
)

# MySQL: connect directly via TCP (non-SSL)
Expand Down Expand Up @@ -279,7 +278,10 @@ def get_absolute_path(path):

for connection_name in connection_names:
task = CloudSQLExecuteQueryOperator(
gcp_cloudsql_conn_id=connection_name, task_id="example_gcp_sql_task_" + connection_name, sql=SQL
gcp_cloudsql_conn_id=connection_name,
task_id="example_gcp_sql_task_" + connection_name,
sql=SQL,
sql_proxy_binary_path=sql_proxy_binary_path,
)
tasks.append(task)
if prev_task:
Expand Down
7 changes: 4 additions & 3 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ class CloudSQLDatabaseHook(BaseHook):
* **public_ip** - IP to connect to for public connection (from host of the URI).
* **public_port** - Port to connect to for public connection (from port of the URI).
* **database** - Database to connect to (from schema of the URI).
* **sql_proxy_binary_path** - Optional path to Cloud SQL Proxy binary. If the binary
is not specified or the binary is not present, it is automatically downloaded.
Remaining parameters are retrieved from the extras (URI query parameters):
Expand All @@ -682,8 +684,6 @@ class CloudSQLDatabaseHook(BaseHook):
You cannot use proxy and SSL together.
* **sql_proxy_use_tcp** - (default False) If set to true, TCP is used to connect via
proxy, otherwise UNIX sockets are used.
* **sql_proxy_binary_path** - Optional path to Cloud SQL Proxy binary. If the binary
is not specified or the binary is not present, it is automatically downloaded.
* **sql_proxy_version** - Specific version of the proxy to download (for example
v1.13). If not specified, the latest version is downloaded.
* **sslcert** - Path to client certificate to authenticate when SSL is used.
Expand All @@ -707,6 +707,7 @@ def __init__(
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
gcp_conn_id: str = "google_cloud_default",
default_gcp_project_id: str | None = None,
sql_proxy_binary_path: str | None = None,
) -> None:
super().__init__()
self.gcp_conn_id = gcp_conn_id
Expand All @@ -722,7 +723,7 @@ def __init__(
self.use_ssl = self._get_bool(self.extras.get("use_ssl", "False"))
self.sql_proxy_use_tcp = self._get_bool(self.extras.get("sql_proxy_use_tcp", "False"))
self.sql_proxy_version = self.extras.get("sql_proxy_version")
self.sql_proxy_binary_path = self.extras.get("sql_proxy_binary_path")
self.sql_proxy_binary_path = sql_proxy_binary_path
self.user = self.cloudsql_connection.login
self.password = self.cloudsql_connection.password
self.public_ip = self.cloudsql_connection.host
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/google/cloud/operators/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ class CloudSQLExecuteQueryOperator(BaseOperator):
its schema should be gcpcloudsql://.
See :class:`~airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook` for
details on how to define ``gcpcloudsql://`` connection.
:param sql_proxy_binary_path: (optional) Path to the cloud-sql-proxy binary.
is not specified or the binary is not present, it is automatically downloaded.
"""

# [START gcp_sql_query_template_fields]
Expand All @@ -1062,6 +1064,7 @@ def __init__(
parameters: Iterable | Mapping | None = None,
gcp_conn_id: str = "google_cloud_default",
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
sql_proxy_binary_path: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1071,6 +1074,7 @@ def __init__(
self.autocommit = autocommit
self.parameters = parameters
self.gcp_connection: Connection | None = None
self.sql_proxy_binary_path = sql_proxy_binary_path

def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
cloud_sql_proxy_runner = None
Expand All @@ -1094,6 +1098,7 @@ def execute(self, context: Context):
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
gcp_conn_id=self.gcp_conn_id,
default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"),
sql_proxy_binary_path=self.sql_proxy_binary_path,
)
hook.validate_ssl_certs()
connection = hook.create_connection()
Expand Down

0 comments on commit 32c571e

Please sign in to comment.