Skip to content

Commit

Permalink
Speech To Text assets & system tests migration (AIP-47) (#23643)
Browse files Browse the repository at this point in the history
Co-authored-by: Wojciech Januszek <[email protected]>
  • Loading branch information
wojsamjan and Wojciech Januszek committed May 17, 2022
1 parent 64d0d9c commit d3b0880
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 56 deletions.
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/operators/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook, RecognitionAudio
from airflow.providers.google.common.links.storage import FileDetailsLink

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -72,6 +73,7 @@ class CloudSpeechToTextRecognizeSpeechOperator(BaseOperator):
"impersonation_chain",
)
# [END gcp_speech_to_text_synthesize_template_fields]
operator_extra_links = (FileDetailsLink(),)

def __init__(
self,
Expand Down Expand Up @@ -106,6 +108,15 @@ def execute(self, context: 'Context'):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio["uri"][5:],
project_id=self.project_id or hook.project_id,
)

response = hook.recognize_speech(
config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ google.cloud.speech_v1.types module

For more information, see: https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/api.html#google.cloud.speech_v1.SpeechClient.recognize

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_speech_to_text.py
.. exampleinclude:: /../../tests/system/providers/google/speech_to_text/example_speech_to_text.py
:language: python
:start-after: [START howto_operator_text_to_speech_api_arguments]
:end-before: [END howto_operator_text_to_speech_api_arguments]

filename is a simple string argument:

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_speech_to_text.py
.. exampleinclude:: /../../tests/system/providers/google/speech_to_text/example_speech_to_text.py
:language: python
:start-after: [START howto_operator_speech_to_text_api_arguments]
:end-before: [END howto_operator_speech_to_text_api_arguments]

Using the operator
""""""""""""""""""

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_speech_to_text.py
.. exampleinclude:: /../../tests/system/providers/google/speech_to_text/example_speech_to_text.py
:language: python
:dedent: 4
:start-after: [START howto_operator_speech_to_text_recognize]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_recognize_speech_green_path(self, mock_hook):
audio=AUDIO,
task_id="id",
impersonation_chain=IMPERSONATION_CHAIN,
).execute(context={"task_instance": Mock()})
).execute(context=MagicMock())

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@
from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator
from airflow.providers.google.cloud.operators.text_to_speech import CloudTextToSpeechSynthesizeOperator
from airflow.utils.trigger_rule import TriggerRule

GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
BUCKET_NAME = os.environ.get("GCP_SPEECH_TO_TEXT_TEST_BUCKET", "INVALID BUCKET NAME")
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "speech_to_text"

BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"

# [START howto_operator_speech_to_text_gcp_filename]
FILENAME = "gcp-speech-test-file"
FILE_NAME = f"test-audio-file-{DAG_ID}-{ENV_ID}"
# [END howto_operator_speech_to_text_gcp_filename]

# [START howto_operator_text_to_speech_api_arguments]
Expand All @@ -38,29 +43,55 @@

# [START howto_operator_speech_to_text_api_arguments]
CONFIG = {"encoding": "LINEAR16", "language_code": "en_US"}
AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILENAME}"}
AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILE_NAME}"}
# [END howto_operator_speech_to_text_api_arguments]

with models.DAG(
"example_gcp_speech_to_text",
schedule_interval='@once', # Override to match your needs
DAG_ID,
schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
tags=['example'],
tags=["example", "speech_to_text"],
) as dag:
create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME)

text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator(
project_id=GCP_PROJECT_ID,
project_id=PROJECT_ID,
input_data=INPUT,
voice=VOICE,
audio_config=AUDIO_CONFIG,
target_bucket_name=BUCKET_NAME,
target_filename=FILENAME,
target_filename=FILE_NAME,
task_id="text_to_speech_synthesize_task",
)
# [START howto_operator_speech_to_text_recognize]
speech_to_text_recognize_task2 = CloudSpeechToTextRecognizeSpeechOperator(
speech_to_text_recognize_task = CloudSpeechToTextRecognizeSpeechOperator(
config=CONFIG, audio=AUDIO, task_id="speech_to_text_recognize_task"
)
# [END howto_operator_speech_to_text_recognize]

text_to_speech_synthesize_task >> speech_to_text_recognize_task2
delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
)

(
# TEST SETUP
create_bucket
# TEST BODY
>> text_to_speech_synthesize_task
>> speech_to_text_recognize_task
# TEST TEARDOWN
>> delete_bucket
)

from tests.system.utils.watcher import watcher

# This test needs the watcher in order to properly mark success/failure
# when a "tearDown" task with a trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

0 comments on commit d3b0880

Please sign in to comment.