Skip to content

Commit

Permalink
Allow and prefer non-prefixed extra fields for dataprep hook (#27039)
Browse files Browse the repository at this point in the history
* No extra prefix required for dataprep hook

From 2.3 onward this convention is no longer needed for web UI custom fields, so this change cleans up the codebase to generally stop using the pattern.
  • Loading branch information
dstandish committed Oct 28, 2022
1 parent afa5ce4 commit 3d5f34c
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
20 changes: 17 additions & 3 deletions airflow/providers/google/cloud/hooks/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@
from airflow.hooks.base import BaseHook


def _get_field(extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
backcompat_prefix = "extra__dataprep__"
if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
"when using this method."
)
if field_name in extras:
return extras[field_name] or None
prefixed_name = f"{backcompat_prefix}{field_name}"
return extras.get(prefixed_name) or None


class GoogleDataprepHook(BaseHook):
"""
Hook for connection with Dataprep API.
Expand All @@ -48,9 +62,9 @@ def __init__(self, dataprep_conn_id: str = default_conn_name) -> None:
super().__init__()
self.dataprep_conn_id = dataprep_conn_id
conn = self.get_connection(self.dataprep_conn_id)
extra_dejson = conn.extra_dejson
self._token = extra_dejson.get("extra__dataprep__token")
self._base_url = extra_dejson.get("extra__dataprep__base_url", "https://api.clouddataprep.com")
extras = conn.extra_dejson
self._token = _get_field(extras, "token")
self._base_url = _get_field(extras, "base_url") or "https://api.clouddataprep.com"

@property
def _headers(self) -> dict[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ Set values for these fields:
.. code-block::
Connection Id: "your_conn_id"
Extra: {"extra__dataprep__token": "TOKEN",
"extra__dataprep__base_url": "https://api.clouddataprep.com"}
Extra: {"token": "TOKEN", "base_url": "https://api.clouddataprep.com"}
Prerequisite Tasks
^^^^^^^^^^^^^^^^^^
Expand Down
26 changes: 20 additions & 6 deletions tests/providers/google/cloud/hooks/test_dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,32 @@
from __future__ import annotations

import json
import unittest
import os
from unittest import mock
from unittest.mock import patch

import pytest
from pytest import param
from requests import HTTPError
from tenacity import RetryError

from airflow.providers.google.cloud.hooks import dataprep
from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

JOB_ID = 1234567
RECIPE_ID = 1234567
TOKEN = "1111"
EXTRA = {"extra__dataprep__token": TOKEN}
EXTRA = {"token": TOKEN}
EMBED = ""
INCLUDE_DELETED = False
DATA = json.dumps({"wrangledDataset": {"id": RECIPE_ID}})
URL = "https://api.clouddataprep.com/v4/jobGroups"


class TestGoogleDataprepHook(unittest.TestCase):
def setUp(self):
class TestGoogleDataprepHook:
def setup(self):
with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
conn.return_value.extra_dejson = EXTRA
self.hook = dataprep.GoogleDataprepHook(dataprep_conn_id="dataprep_default")
self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")

@patch("airflow.providers.google.cloud.hooks.dataprep.requests.get")
def test_get_jobs_for_job_group_should_be_called_once_with_params(self, mock_get_request):
Expand Down Expand Up @@ -204,3 +205,16 @@ def test_run_job_group_raise_error_after_five_calls(self, mock_get_request):
self.hook.run_job_group(body_request=DATA)
assert "HTTPError" in str(ctx.value)
assert mock_get_request.call_count == 5

@pytest.mark.parametrize(
    "uri",
    [
        param("a://?extra__dataprep__token=abc&extra__dataprep__base_url=abc", id="prefix"),
        param("a://?token=abc&base_url=abc", id="no-prefix"),
    ],
)
def test_conn_extra_backcompat_prefix(self, uri):
    """Check that prefixed and non-prefixed extras give the same token and base URL."""
    # The connection is supplied through the AIRFLOW_CONN_MY_CONN environment
    # variable, patched in only for the duration of this test.
    env_override = {"AIRFLOW_CONN_MY_CONN": uri}
    with patch.dict(os.environ, env_override):
        dataprep_hook = GoogleDataprepHook("my_conn")
        assert dataprep_hook._token == "abc"
        assert dataprep_hook._base_url == "abc"
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest

TOKEN = environ.get("DATAPREP_TOKEN")
EXTRA = {"extra__dataprep__token": TOKEN}
EXTRA = {"token": TOKEN}


@pytest.mark.skipif(TOKEN is None, reason="Dataprep token not present")
Expand Down

0 comments on commit 3d5f34c

Please sign in to comment.