Skip to content

Commit

Permalink
Standardize SecretBackend class names (#7846)
Browse files Browse the repository at this point in the history
- AwsSsmSecretsBackend -> SystemsManagerParameterStoreBackend
- CloudSecretsManagerSecretsBackend -> CloudSecretsManagerBackend
- VaultSecrets -> VaultBackend
- EnvironmentVariablesSecretsBackend -> EnvironmentVariablesBackend
- MetastoreSecretsBackend -> MetastoreBackend
  • Loading branch information
kaxil committed Mar 25, 2020
1 parent 02b71f9 commit 686d7d5
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 47 deletions.
Expand Up @@ -27,7 +27,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin


class AwsSsmSecretsBackend(BaseSecretsBackend, LoggingMixin):
class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from AWS SSM Parameter Store
Expand All @@ -36,7 +36,7 @@ class AwsSsmSecretsBackend(BaseSecretsBackend, LoggingMixin):
.. code-block:: ini
[secrets]
backend = airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend
backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend
backend_kwargs = {"connections_prefix": "/airflow/connections", "profile_name": null}
For example, if ssm path is ``/airflow/connections/smtp_default``, this would be accessible
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/secrets/secrets_manager.py
Expand Up @@ -33,7 +33,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin


class CloudSecretsManagerSecretsBackend(BaseSecretsBackend, LoggingMixin):
class CloudSecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from GCP Secrets Manager
Expand All @@ -42,7 +42,7 @@ class CloudSecretsManagerSecretsBackend(BaseSecretsBackend, LoggingMixin):
.. code-block:: ini
[secrets]
backend = airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerSecretsBackend
backend = airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerBackend
backend_kwargs = {"connections_prefix": "airflow/connections"}
For example, if secret id is ``airflow/connections/smtp_default``, this would be accessible
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/hashicorp/secrets/vault.py
Expand Up @@ -29,7 +29,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin


class VaultSecrets(BaseSecretsBackend, LoggingMixin):
class VaultBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from Hashicorp Vault
Expand All @@ -38,7 +38,7 @@ class VaultSecrets(BaseSecretsBackend, LoggingMixin):
.. code-block:: ini
[secrets]
backend = airflow.providers.hashicorp.secrets.vault.VaultSecrets
backend = airflow.providers.hashicorp.secrets.vault.VaultBackend
backend_kwargs = {
"connections_path": "connections",
"url": "http://127.0.0.1:8200",
Expand Down
4 changes: 2 additions & 2 deletions airflow/secrets/__init__.py
Expand Up @@ -36,8 +36,8 @@

CONFIG_SECTION = "secrets"
DEFAULT_SECRETS_SEARCH_PATH = [
"airflow.secrets.environment_variables.EnvironmentVariablesSecretsBackend",
"airflow.secrets.metastore.MetastoreSecretsBackend",
"airflow.secrets.environment_variables.EnvironmentVariablesBackend",
"airflow.secrets.metastore.MetastoreBackend",
]


Expand Down
2 changes: 1 addition & 1 deletion airflow/secrets/environment_variables.py
Expand Up @@ -27,7 +27,7 @@
CONN_ENV_PREFIX = "AIRFLOW_CONN_"


class EnvironmentVariablesSecretsBackend(BaseSecretsBackend):
class EnvironmentVariablesBackend(BaseSecretsBackend):
"""
Retrieves Connection object from environment variable.
"""
Expand Down
2 changes: 1 addition & 1 deletion airflow/secrets/metastore.py
Expand Up @@ -26,7 +26,7 @@
from airflow.utils.session import provide_session


class MetastoreSecretsBackend(BaseSecretsBackend):
class MetastoreBackend(BaseSecretsBackend):
"""
Retrieves Connection object from airflow metastore database.
"""
Expand Down
12 changes: 6 additions & 6 deletions docs/howto/use-alternative-secrets-backend.rst
Expand Up @@ -57,15 +57,15 @@ See :ref:`AWS SSM Parameter Store <ssm_parameter_store_secrets>` for an example
AWS SSM Parameter Store Secrets Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To enable SSM parameter store, specify :py:class:`~airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend`
To enable SSM parameter store, specify :py:class:`~airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend`
as the ``backend`` in ``[secrets]`` section of ``airflow.cfg``.

Here is a sample configuration:

.. code-block:: ini
[secrets]
backend = airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend
backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend
backend_kwargs = {"connections_prefix": "/airflow/connections", "profile_name": "default"}
If you have set ``connections_prefix`` as ``/airflow/connections``, then for a connection id of ``smtp_default``,
Expand All @@ -81,15 +81,15 @@ of the connection object.
Hashicorp Vault Secrets Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To enable Hashicorp vault to retrieve connection, specify :py:class:`~airflow.providers.hashicorp.secrets.vault.VaultSecrets`
To enable Hashicorp vault to retrieve connection, specify :py:class:`~airflow.providers.hashicorp.secrets.vault.VaultBackend`
as the ``backend`` in ``[secrets]`` section of ``airflow.cfg``.

Here is a sample configuration:

.. code-block:: ini
[secrets]
backend = airflow.providers.hashicorp.secrets.vault.VaultSecrets
backend = airflow.providers.hashicorp.secrets.vault.VaultBackend
backend_kwargs = {"connections_path": "connections", "mount_point": "airflow", "url": "http://127.0.0.1:8200"}
The default KV version engine is ``2``, pass ``kv_engine_version: 1`` in ``backend_kwargs`` if you use
Expand Down Expand Up @@ -147,7 +147,7 @@ of the connection object.
GCP Secrets Manager Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^

To enable GCP Secrets Manager to retrieve connection, specify :py:class:`~airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerSecretsBackend`
To enable GCP Secrets Manager to retrieve connection, specify :py:class:`~airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerBackend`
as the ``backend`` in ``[secrets]`` section of ``airflow.cfg``.

Available parameters to ``backend_kwargs``:
Expand All @@ -161,7 +161,7 @@ Here is a sample configuration:
.. code-block:: ini
[secrets]
backend = airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerSecretsBackend
backend = airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerBackend
backend_kwargs = {"connections_prefix": "airflow/connections"}
When ``gcp_key_path`` is not provided, it will use the Application Default Credentials in the current environment. You can set up the credentials with:
Expand Down
Expand Up @@ -19,14 +19,15 @@

from moto import mock_ssm

from airflow.providers.amazon.aws.secrets.ssm import AwsSsmSecretsBackend
from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend


class TestSsmSecrets(TestCase):
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
@mock.patch("airflow.providers.amazon.aws.secrets.systems_manager."
"SystemsManagerParameterStoreBackend.get_conn_uri")
def test_aws_ssm_get_connections(self, mock_get_uri):
mock_get_uri.return_value = "scheme://user:pass@host:100"
conn_list = AwsSsmSecretsBackend().get_connections("fake_conn")
conn_list = SystemsManagerParameterStoreBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'

Expand All @@ -38,7 +39,7 @@ def test_get_conn_uri(self):
'Value': 'postgresql://airflow:airflow@host:5432/airflow'
}

ssm_backend = AwsSsmSecretsBackend()
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)

returned_uri = ssm_backend.get_conn_uri(conn_id="test_postgres")
Expand All @@ -48,7 +49,7 @@ def test_get_conn_uri(self):
def test_get_conn_uri_non_existent_key(self):
"""
Test that if the key with connection ID is not present in SSM,
AwsSsmSecretsBackend.get_connections should return None
SystemsManagerParameterStoreBackend.get_connections should return None
"""
conn_id = "test_mysql"
param = {
Expand All @@ -57,7 +58,7 @@ def test_get_conn_uri_non_existent_key(self):
'Value': 'postgresql://airflow:airflow@host:5432/airflow'
}

ssm_backend = AwsSsmSecretsBackend()
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)

self.assertIsNone(ssm_backend.get_conn_uri(conn_id=conn_id))
Expand Down
10 changes: 5 additions & 5 deletions tests/providers/google/cloud/secrets/test_secrets_manager.py
Expand Up @@ -22,7 +22,7 @@
from parameterized import parameterized

from airflow.models import Connection
from airflow.providers.google.cloud.secrets.secrets_manager import CloudSecretsManagerSecretsBackend
from airflow.providers.google.cloud.secrets.secrets_manager import CloudSecretsManagerBackend

CREDENTIALS = 'test-creds'
KEY_FILE = 'test-file.json'
Expand Down Expand Up @@ -50,19 +50,19 @@ def test_get_conn_uri(self, connections_prefix, mock_client_callable, mock_get_c
test_response.payload.data = CONN_URI.encode("UTF-8")
mock_client.access_secret_version.return_value = test_response

secrets_manager_backend = CloudSecretsManagerSecretsBackend(connections_prefix=connections_prefix)
secrets_manager_backend = CloudSecretsManagerBackend(connections_prefix=connections_prefix)
returned_uri = secrets_manager_backend.get_conn_uri(conn_id=CONN_ID)
self.assertEqual(CONN_URI, returned_uri)
mock_client.secret_version_path.assert_called_once_with(
PROJECT_ID, f"{connections_prefix}/{CONN_ID}", "latest"
)

@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@mock.patch(MODULE_NAME + ".CloudSecretsManagerSecretsBackend.get_conn_uri")
@mock.patch(MODULE_NAME + ".CloudSecretsManagerBackend.get_conn_uri")
def test_get_connections(self, mock_get_uri, mock_get_creds):
mock_get_creds.return_value = CREDENTIALS, PROJECT_ID
mock_get_uri.return_value = CONN_URI
conns = CloudSecretsManagerSecretsBackend().get_connections(conn_id=CONN_ID)
conns = CloudSecretsManagerBackend().get_connections(conn_id=CONN_ID)
self.assertIsInstance(conns, list)
self.assertIsInstance(conns[0], Connection)

Expand All @@ -77,7 +77,7 @@ def test_get_conn_uri_non_existent_key(self, mock_client_callable, mock_get_cred

connections_prefix = "airflow/connections"

secrets_manager_backend = CloudSecretsManagerSecretsBackend(connections_prefix=connections_prefix)
secrets_manager_backend = CloudSecretsManagerBackend(connections_prefix=connections_prefix)
with self.assertLogs(secrets_manager_backend.log, level="ERROR") as log_output:
self.assertIsNone(secrets_manager_backend.get_conn_uri(conn_id=CONN_ID))
self.assertEqual([], secrets_manager_backend.get_connections(conn_id=CONN_ID))
Expand Down
12 changes: 6 additions & 6 deletions tests/providers/hashicorp/secrets/test_vault.py
Expand Up @@ -19,7 +19,7 @@

from hvac.exceptions import InvalidPath, VaultError

from airflow.providers.hashicorp.secrets.vault import VaultSecrets
from airflow.providers.hashicorp.secrets.vault import VaultBackend


class TestVaultSecrets(TestCase):
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_get_conn_uri(self, mock_hvac):
"token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS"
}

test_client = VaultSecrets(**kwargs)
test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)

Expand All @@ -79,7 +79,7 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac):
"kv_engine_version": 1
}

test_client = VaultSecrets(**kwargs)
test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point='airflow', path='connections/test_postgres')
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_get_conn_uri_non_existent_key(self, mock_hvac):
"token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS"
}

test_client = VaultSecrets(**kwargs)
test_client = VaultBackend(**kwargs)
self.assertIsNone(test_client.get_conn_uri(conn_id="test_mysql"))
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='airflow', path='connections/test_mysql')
Expand All @@ -128,7 +128,7 @@ def test_auth_failure_raises_error(self, mock_hvac):
}

with self.assertRaisesRegex(VaultError, "Vault Authentication Error!"):
VaultSecrets(**kwargs).get_connections(conn_id='test')
VaultBackend(**kwargs).get_connections(conn_id='test')

@mock.patch("airflow.providers.hashicorp.secrets.vault.hvac")
def test_empty_token_raises_error(self, mock_hvac):
Expand All @@ -143,4 +143,4 @@ def test_empty_token_raises_error(self, mock_hvac):
}

with self.assertRaisesRegex(VaultError, "token cannot be None for auth_type='token'"):
VaultSecrets(**kwargs).get_connections(conn_id='test')
VaultBackend(**kwargs).get_connections(conn_id='test')
23 changes: 13 additions & 10 deletions tests/secrets/test_secrets.py
Expand Up @@ -24,51 +24,54 @@


class TestSecrets(unittest.TestCase):
@mock.patch("airflow.secrets.metastore.MetastoreSecretsBackend.get_connections")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesSecretsBackend.get_connections")
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connections")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections")
def test_get_connections_second_try(self, mock_env_get, mock_meta_get):
mock_env_get.side_effect = [[]] # return empty list
get_connections("fake_conn_id")
mock_meta_get.assert_called_once_with(conn_id="fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")

@mock.patch("airflow.secrets.metastore.MetastoreSecretsBackend.get_connections")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesSecretsBackend.get_connections")
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connections")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections")
def test_get_connections_first_try(self, mock_env_get, mock_meta_get):
mock_env_get.side_effect = [["something"]] # returns nonempty list
get_connections("fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
mock_meta_get.not_called()

@conf_vars({
("secrets", "backend"): "airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend",
("secrets", "backend"):
"airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
})
def test_initialize_secrets_backends(self):
backends = initialize_secrets_backends()
backend_classes = [backend.__class__.__name__ for backend in backends]

self.assertEqual(3, len(backends))
self.assertIn('AwsSsmSecretsBackend', backend_classes)
self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)

@conf_vars({
("secrets", "backend"): "airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend",
("secrets", "backend"):
"airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend",
("secrets", "backend_kwargs"): '{"connections_prefix": "/airflow", "profile_name": null}',
})
@mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
})
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
@mock.patch("airflow.providers.amazon.aws.secrets.systems_manager."
"SystemsManagerParameterStoreBackend.get_conn_uri")
def test_backend_fallback_to_env_var(self, mock_get_uri):
mock_get_uri.return_value = None

backends = ensure_secrets_loaded()
backend_classes = [backend.__class__.__name__ for backend in backends]
self.assertIn('AwsSsmSecretsBackend', backend_classes)
self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)

uri = get_connections(conn_id="test_mysql")

# Assert that AwsSsmSecretsBackend.get_conn_uri was called
# Assert that SystemsManagerParameterStoreBackend.get_conn_uri was called
mock_get_uri.assert_called_once_with(conn_id='test_mysql')

self.assertEqual('mysql://airflow:airflow@host:5432/airflow', uri[0].get_uri())
Expand Down
8 changes: 4 additions & 4 deletions tests/secrets/test_secrets_backends.py
Expand Up @@ -20,8 +20,8 @@
import unittest

from airflow.models import Connection
from airflow.secrets.environment_variables import EnvironmentVariablesSecretsBackend
from airflow.secrets.metastore import MetastoreSecretsBackend
from airflow.secrets.environment_variables import EnvironmentVariablesBackend
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.session import create_session


Expand All @@ -39,7 +39,7 @@ def __init__(self, conn_id, variation: str):
class TestBaseSecretsBackend(unittest.TestCase):
def test_env_secrets_backend(self):
sample_conn_1 = SampleConn("sample_1", "A")
env_secrets_backend = EnvironmentVariablesSecretsBackend()
env_secrets_backend = EnvironmentVariablesBackend()
os.environ[sample_conn_1.var_name] = sample_conn_1.conn_uri
conn_list = env_secrets_backend.get_connections(sample_conn_1.conn_id)
self.assertEqual(1, len(conn_list))
Expand All @@ -55,7 +55,7 @@ def test_metastore_secrets_backend(self):
session.add(sample_conn_2a.conn)
session.add(sample_conn_2b.conn)
session.commit()
metastore_backend = MetastoreSecretsBackend()
metastore_backend = MetastoreBackend()
conn_list = metastore_backend.get_connections("sample_2")
host_list = {x.host for x in conn_list}
self.assertEqual(
Expand Down

0 comments on commit 686d7d5

Please sign in to comment.