Skip to content

Commit

Permalink
Add type annotations for mlengine_operator_utils (#10297)
Browse files Browse the repository at this point in the history
Add type annotations, including a few changes to ensure the right types
are passed through. Specifically, if region is not given, it must be
provided in the DAG's default_args.
  • Loading branch information
coopergillan committed Aug 16, 2020
1 parent 382c101 commit e195a98
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions airflow/providers/google/cloud/utils/mlengine_operator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,35 @@
import json
import os
import re
from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
from urllib.parse import urlsplit

import dill

from airflow import DAG
from airflow.exceptions import AirflowException
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator


def create_evaluate_ops(task_prefix, # pylint: disable=too-many-arguments
data_format,
input_paths,
prediction_path,
metric_fn_and_keys,
validate_fn,
batch_prediction_job_id=None,
project_id=None,
region=None,
dataflow_options=None,
model_uri=None,
model_name=None,
version_name=None,
dag=None,
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def create_evaluate_ops(task_prefix: str, # pylint: disable=too-many-arguments
data_format: str,
input_paths: List[str],
prediction_path: str,
metric_fn_and_keys: Tuple[T, Iterable[str]],
validate_fn: T,
batch_prediction_job_id: Optional[str] = None,
region: Optional[str] = None,
project_id: Optional[str] = None,
dataflow_options: Optional[Dict] = None,
model_uri: Optional[str] = None,
model_name: Optional[str] = None,
version_name: Optional[str] = None,
dag: Optional[DAG] = None,
py_interpreter="python3"):
"""
Creates Operators needed for model evaluation and returns.
Expand Down Expand Up @@ -186,6 +190,9 @@ def validate_err_and_count(summary):
:rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
PythonOperator)
"""
batch_prediction_job_id = batch_prediction_job_id or ""
dataflow_options = dataflow_options or {}
region = region or ""

# Verify that task_prefix doesn't have any special characters except hyphen
# '-', which is the only allowed non-alphanumeric character by Dataflow.
Expand All @@ -203,7 +210,7 @@ def validate_err_and_count(summary):
if dag is not None and dag.default_args is not None:
default_args = dag.default_args
project_id = project_id or default_args.get('project_id')
region = region or default_args.get('region')
region = region or default_args['region']
model_name = model_name or default_args.get('model_name')
version_name = version_name or default_args.get('version_name')
dataflow_options = dataflow_options or \
Expand Down

0 comments on commit e195a98

Please sign in to comment.