Skip to content

Commit

Permalink
Support google-cloud-automl >=2.1.0 (#13505)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj committed Jan 11, 2021
1 parent 947dbb7 commit a6f999b
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 113 deletions.
1 change: 1 addition & 0 deletions airflow/providers/google/ADDITIONAL_INFO.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som

| Library name | Previous constraints | Current constraints | |
| --- | --- | --- | --- |
| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-automl/blob/master/UPGRADING.md) |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
GCP_AUTOML_DATASET_BUCKET = os.environ.get(
"GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv"
)
TARGET = os.environ.get("GCP_AUTOML_TARGET", "Class")
TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit")

# Example values
MODEL_ID = "TBL123456"
Expand Down Expand Up @@ -76,9 +76,9 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
Using column name returns spec of the column.
"""
for column in columns_specs:
if column["displayName"] == column_name:
if column["display_name"] == column_name:
return extract_object_id(column)
return ""
raise Exception(f"Unknown target column: {column_name}")


# Example DAG to create dataset, train model_id and deploy it.
Expand Down
103 changes: 53 additions & 50 deletions airflow/providers/google/cloud/hooks/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,23 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

from cached_property import cached_property
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from google.cloud.automl_v1beta1.types import (
from google.cloud.automl_v1beta1 import (
AutoMlClient,
BatchPredictInputConfig,
BatchPredictOutputConfig,
ColumnSpec,
Dataset,
ExamplePayload,
FieldMask,
ImageObjectDetectionModelDeploymentMetadata,
InputConfig,
Model,
Operation,
PredictionServiceClient,
PredictResponse,
TableSpec,
)
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

Expand Down Expand Up @@ -123,9 +124,9 @@ def create_model(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
parent = client.location_path(project_id, location)
parent = f"projects/{project_id}/locations/{location}"
return client.create_model(
parent=parent, model=model, retry=retry, timeout=timeout, metadata=metadata
request={'parent': parent, 'model': model}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -176,15 +177,17 @@ def batch_predict(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.prediction_client
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.batch_predict(
name=name,
input_config=input_config,
output_config=output_config,
params=params,
request={
'name': name,
'input_config': input_config,
'output_config': output_config,
'params': params,
},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -229,14 +232,12 @@ def predict(
:return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
"""
client = self.prediction_client
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.predict(
name=name,
payload=payload,
params=params,
request={'name': name, 'payload': payload, 'params': params},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -273,13 +274,12 @@ def create_dataset(
:return: `google.cloud.automl_v1beta1.types.Dataset` instance.
"""
client = self.get_conn()
parent = client.location_path(project=project_id, location=location)
parent = f"projects/{project_id}/locations/{location}"
result = client.create_dataset(
parent=parent,
dataset=dataset,
request={'parent': parent, 'dataset': dataset},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -319,13 +319,12 @@ def import_data(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.import_data(
name=name,
input_config=input_config,
request={'name': name, 'input_config': input_config},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -385,13 +384,10 @@ def list_column_specs( # pylint: disable=too-many-arguments
table_spec=table_spec_id,
)
result = client.list_column_specs(
parent=parent,
field_mask=field_mask,
filter_=filter_,
page_size=page_size,
request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -427,8 +423,10 @@ def get_model(
:return: `google.cloud.automl_v1beta1.types.Model` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.get_model(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -463,8 +461,10 @@ def delete_model(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.delete_model(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result

def update_dataset(
Expand Down Expand Up @@ -497,11 +497,10 @@ def update_dataset(
"""
client = self.get_conn()
result = client.update_dataset(
dataset=dataset,
update_mask=update_mask,
request={'dataset': dataset, 'update_mask': update_mask},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -547,13 +546,15 @@ def deploy_model(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.deploy_model(
name=name,
request={
'name': name,
'image_object_detection_model_deployment_metadata': image_detection_metadata,
},
retry=retry,
timeout=timeout,
metadata=metadata,
image_object_detection_model_deployment_metadata=image_detection_metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -601,14 +602,12 @@ def list_table_specs(
of the response through the `options` parameter.
"""
client = self.get_conn()
parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.list_table_specs(
parent=parent,
filter_=filter_,
page_size=page_size,
request={'parent': parent, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result

Expand Down Expand Up @@ -644,8 +643,10 @@ def list_datasets(
of the response through the `options` parameter.
"""
client = self.get_conn()
parent = client.location_path(project=project_id, location=location)
result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata)
parent = f"projects/{project_id}/locations/{location}"
result = client.list_datasets(
request={'parent': parent}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -680,6 +681,8 @@ def delete_dataset(
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.delete_dataset(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result
36 changes: 20 additions & 16 deletions airflow/providers/google/cloud/operators/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

from google.api_core.retry import Retry
from google.protobuf.json_format import MessageToDict
from google.cloud.automl_v1beta1 import (
BatchPredictResult,
ColumnSpec,
Dataset,
Model,
PredictResponse,
TableSpec,
)

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
Expand Down Expand Up @@ -113,7 +120,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
result = Model.to_dict(operation.result())
model_id = hook.extract_object_id(result)
self.log.info("Model created: %s", model_id)

Expand Down Expand Up @@ -212,7 +219,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(result)
return PredictResponse.to_dict(result)


class AutoMLBatchPredictOperator(BaseOperator):
Expand Down Expand Up @@ -324,7 +331,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
result = BatchPredictResult.to_dict(operation.result())
self.log.info("Batch prediction ready.")
return result

Expand Down Expand Up @@ -414,7 +421,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(result)
result = Dataset.to_dict(result)
dataset_id = hook.extract_object_id(result)
self.log.info("Creating completed. Dataset id: %s", dataset_id)

Expand Down Expand Up @@ -513,9 +520,8 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
operation.result()
self.log.info("Import completed")
return result


class AutoMLTablesListColumnSpecsOperator(BaseOperator):
Expand Down Expand Up @@ -627,7 +633,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(spec) for spec in page_iterator]
result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
self.log.info("Columns specs obtained.")

return result
Expand Down Expand Up @@ -718,7 +724,7 @@ def execute(self, context):
metadata=self.metadata,
)
self.log.info("Dataset updated.")
return MessageToDict(result)
return Dataset.to_dict(result)


class AutoMLGetModelOperator(BaseOperator):
Expand Down Expand Up @@ -804,7 +810,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(result)
return Model.to_dict(result)


class AutoMLDeleteModelOperator(BaseOperator):
Expand Down Expand Up @@ -890,8 +896,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
return result
operation.result()


class AutoMLDeployModelOperator(BaseOperator):
Expand Down Expand Up @@ -991,9 +996,8 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
operation.result()
self.log.info("Model deployed.")
return result


class AutoMLTablesListTableSpecsOperator(BaseOperator):
Expand Down Expand Up @@ -1092,7 +1096,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(spec) for spec in page_iterator]
result = [TableSpec.to_dict(spec) for spec in page_iterator]
self.log.info(result)
self.log.info("Table specs obtained.")
return result
Expand Down Expand Up @@ -1173,7 +1177,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(dataset) for dataset in page_iterator]
result = [Dataset.to_dict(dataset) for dataset in page_iterator]
self.log.info("Datasets obtained.")

self.xcom_push(
Expand Down

0 comments on commit a6f999b

Please sign in to comment.