Migrating Google AutoML example_dags to sys tests (#32368)

---------

Co-authored-by: Amogh Desai <[email protected]>
amoghrajesh and Amogh Desai committed Jul 7, 2023
1 parent 3a14e84 commit 6c854dc
Showing 6 changed files with 93 additions and 31 deletions.
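
All six files receive the same treatment, so the overall shape of a migrated system-test DAG is sketched below for reference. This is a minimal, generic illustration rather than code from the commit: BashOperator stands in for the real AutoML operators, the dag_id and task_ids are made up, and the tests.system.utils imports resolve only inside an Airflow source checkout's tests/system package.

from __future__ import annotations

from datetime import datetime

from airflow import models
from airflow.operators.bash import BashOperator
from airflow.utils.trigger_rule import TriggerRule

with models.DAG(
    "example_sys_test_pattern_sketch",  # hypothetical dag_id; each migrated DAG keeps its own
    schedule="@once",  # system tests run exactly once per invocation
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example"],
) as dag:
    # TEST SETUP: create the resource under test (here just echoing a fake id).
    create_resource = BashOperator(task_id="create_resource", bash_command="echo resource-123")

    # Passing the setup task's XCom output as an operator argument wires the
    # upstream dependency implicitly -- this is what the "Task dependencies
    # created via `XComArgs`" comments in the diff refer to.
    use_resource = BashOperator(
        task_id="use_resource",
        bash_command="echo using $RESOURCE_ID",
        env={"RESOURCE_ID": create_resource.output},  # XComArg => create_resource >> use_resource
        append_env=True,
    )

    # TEST TEARDOWN: runs regardless of upstream failures so resources get cleaned up.
    delete_resource = BashOperator(
        task_id="delete_resource",
        bash_command="echo deleting resource",
        trigger_rule=TriggerRule.ALL_DONE,
    )

    use_resource >> delete_resource

    # The ALL_DONE teardown would otherwise mask task failures, so every task is
    # also fed into the watcher, which fails the whole run if any task fails.
    from tests.system.utils.watcher import watcher

    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Exposes the DAG to pytest (see: tests/system/README.md#run_via_pytest).
test_run = get_test_run(dag)
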
@@ -39,13 +39,9 @@
GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
GCP_AUTOML_TEXT_CLS_BUCKET = os.environ.get("GCP_AUTOML_TEXT_CLS_BUCKET", "gs://INVALID BUCKET NAME")

# Example values
DATASET_ID = ""

# Example model
MODEL = {
"display_name": "auto_model_1",
"dataset_id": DATASET_ID,
"text_classification_model_metadata": {},
}

@@ -55,6 +51,7 @@
"text_classification_dataset_metadata": {"classification_type": "MULTICLASS"},
}


IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_CLS_BUCKET]}}

extract_object_id = CloudAutoMLHook.extract_object_id
@@ -65,24 +62,23 @@
start_date=datetime(2021, 1, 1),
catchup=False,
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
MODEL["dataset_id"] = dataset_id

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
dataset_id=dataset_id,
location=GCP_AUTOML_LOCATION,
input_config=IMPORT_INPUT_CONFIG,
)

MODEL["dataset_id"] = dataset_id

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = cast(str, XComArg(create_model, key="model_id"))

delete_model_task = AutoMLDeleteModelOperator(
@@ -99,10 +95,23 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
@@ -39,13 +39,9 @@
GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
GCP_AUTOML_SENTIMENT_BUCKET = os.environ.get("GCP_AUTOML_SENTIMENT_BUCKET", "gs://INVALID BUCKET NAME")

# Example values
DATASET_ID = ""

# Example model
MODEL = {
"display_name": "auto_model_1",
"dataset_id": DATASET_ID,
"text_sentiment_model_metadata": {},
}

@@ -66,12 +62,13 @@
catchup=False,
user_defined_macros={"extract_object_id": extract_object_id},
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
MODEL["dataset_id"] = dataset_id

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -100,11 +97,24 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_model >> delete_model_task
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
@@ -41,13 +41,9 @@
"GCP_AUTOML_TRANSLATION_BUCKET", "gs://INVALID BUCKET NAME/file"
)

# Example values
DATASET_ID = "TRL123456789"

# Example model
MODEL = {
"display_name": "auto_model_1",
"dataset_id": DATASET_ID,
"translation_model_metadata": {},
}

@@ -60,6 +56,7 @@
},
}


IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TRANSLATION_BUCKET]}}

extract_object_id = CloudAutoMLHook.extract_object_id
@@ -69,10 +66,11 @@
with models.DAG(
"example_automl_translation",
start_date=datetime(2021, 1, 1),
schedule="@once",
catchup=False,
user_defined_macros={"extract_object_id": extract_object_id},
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)
@@ -106,11 +104,25 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_model >> delete_model_task
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
@@ -41,13 +41,9 @@
"GCP_AUTOML_VIDEO_BUCKET", "gs://INVALID BUCKET NAME/hmdb_split1.csv"
)

# Example values
DATASET_ID = "VCN123455678"

# Example model
MODEL = {
"display_name": "auto_model_1",
"dataset_id": DATASET_ID,
"video_classification_model_metadata": {},
}

@@ -69,12 +65,13 @@
catchup=False,
user_defined_macros={"extract_object_id": extract_object_id},
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
MODEL["dataset_id"] = dataset_id

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -103,11 +100,24 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_model >> delete_model_task
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
@@ -42,13 +42,10 @@
"gs://INVALID BUCKET NAME/youtube_8m_videos_animal_tiny.csv",
)

# Example values
DATASET_ID = "VOT123456789"

# Example model
MODEL = {
"display_name": "auto_model_1",
"dataset_id": DATASET_ID,
"video_object_tracking_model_metadata": {},
}

@@ -70,12 +67,13 @@
catchup=False,
user_defined_macros={"extract_object_id": extract_object_id},
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
MODEL["dataset_id"] = dataset_id

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -104,11 +102,24 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_model >> delete_model_task
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
@@ -41,13 +41,9 @@
"GCP_AUTOML_DETECTION_BUCKET", "gs://INVALID BUCKET NAME/img/openimage/csv/salads_ml_use.csv"
)

# Example values
DATASET_ID = ""

# Example model
MODEL = {
"display_name": "auto_model",
"dataset_id": DATASET_ID,
"image_object_detection_model_metadata": {},
}

@@ -69,12 +65,13 @@
catchup=False,
user_defined_macros={"extract_object_id": extract_object_id},
tags=["example"],
) as example_dag:
) as dag:
create_dataset_task = AutoMLCreateDatasetOperator(
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
MODEL["dataset_id"] = dataset_id

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -103,11 +100,24 @@
project_id=GCP_PROJECT_ID,
)

# TEST BODY
import_dataset_task >> create_model
# TEST TEARDOWN
delete_model_task >> delete_datasets_task

# Task dependencies created via `XComArgs`:
# create_dataset_task >> import_dataset_task
# create_dataset_task >> create_model
# create_model >> delete_model_task
# create_dataset_task >> delete_datasets_task

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
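
As the added comments note, each migrated file now ends with test_run = get_test_run(dag), which is what pytest picks up: running one of these system tests amounts to pointing pytest at the corresponding file from an Airflow source checkout, with GCP_PROJECT_ID, GCP_AUTOML_LOCATION and the relevant bucket environment variables set; tests/system/README.md#run_via_pytest describes the exact invocation.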
