Commit

Use literal dict instead of calling dict() in providers (#33761)
hussein-awala committed Aug 26, 2023
1 parent 1e81ed1 commit b115257
Showing 23 changed files with 491 additions and 484 deletions.
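
Every hunk below applies the same mechanical substitution: a keyword-argument call to dict() is rewritten as an equivalent dict literal. A minimal before/after sketch of the pattern (the key names here are illustrative, not taken from any particular provider):

    # Before: build the mapping through the dict() constructor
    kwargs = dict(name="example", namespace="default")

    # After: the equivalent dict literal
    kwargs = {"name": "example", "namespace": "default"}

Both forms produce the same mapping; the literal skips a name lookup and a function call, is what linters such as flake8-comprehensions (rule C408, "unnecessary dict call") recommend, and also allows keys that are not valid Python identifiers.
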
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -188,7 +188,7 @@ def get_table_primary_key(
pk_columns = []
token = ""
while True:
-kwargs = dict(Id=stmt_id)
+kwargs = {"Id": stmt_id}
if token:
kwargs["NextToken"] = token
response = self.conn.get_statement_result(**kwargs)
@@ -246,7 +246,7 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> Non
if ti.map_index >= 0:
# Old tasks _couldn't_ be mapped, so we don't have to worry about compat
base_label_selector += f",map_index={ti.map_index}"
-kwargs = dict(label_selector=base_label_selector)
+kwargs = {"label_selector": base_label_selector}
if self.kube_config.kube_client_request_args:
kwargs.update(**self.kube_config.kube_client_request_args)

8 changes: 4 additions & 4 deletions airflow/providers/cncf/kubernetes/operators/pod.py
@@ -852,10 +852,10 @@ def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
def on_kill(self) -> None:
if self.pod:
pod = self.pod
-kwargs = dict(
-    name=pod.metadata.name,
-    namespace=pod.metadata.namespace,
-)
+kwargs = {
+    "name": pod.metadata.name,
+    "namespace": pod.metadata.namespace,
+}
if self.termination_grace_period is not None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.client.delete_namespaced_pod(**kwargs)
12 changes: 6 additions & 6 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -121,12 +121,12 @@ def my_after_func(retry_state):
self.retry_args["retry"] = retry_if_exception(self._retryable_error)
self.retry_args["after"] = my_after_func
else:
-self.retry_args = dict(
-    stop=stop_after_attempt(self.retry_limit),
-    wait=wait_exponential(min=self.retry_delay, max=(2**retry_limit)),
-    retry=retry_if_exception(self._retryable_error),
-    after=my_after_func,
-)
+self.retry_args = {
+    "stop": stop_after_attempt(self.retry_limit),
+    "wait": wait_exponential(min=self.retry_delay, max=(2**retry_limit)),
+    "retry": retry_if_exception(self._retryable_error),
+    "after": my_after_func,
+}

@cached_property
def databricks_conn(self) -> Connection:
18 changes: 9 additions & 9 deletions airflow/providers/docker/decorators/docker.py
@@ -112,15 +112,15 @@ def execute(self, context: Context):
self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)
py_source = self.get_python_source()
write_python_script(
-jinja_context=dict(
-    op_args=self.op_args,
-    op_kwargs=self.op_kwargs,
-    pickling_library=self.pickling_library.__name__,
-    python_callable=self.python_callable.__name__,
-    python_callable_source=py_source,
-    expect_airflow=self.expect_airflow,
-    string_args_global=False,
-),
+jinja_context={
+    "op_args": self.op_args,
+    "op_kwargs": self.op_kwargs,
+    "pickling_library": self.pickling_library.__name__,
+    "python_callable": self.python_callable.__name__,
+    "python_callable_source": py_source,
+    "expect_airflow": self.expect_airflow,
+    "string_args_global": False,
+},
filename=script_filename,
)

14 changes: 7 additions & 7 deletions airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -92,13 +92,13 @@ def get_conn(self) -> ESConnection:
conn_id = getattr(self, self.conn_name_attr)
conn = self.connection or self.get_connection(conn_id)

-conn_args = dict(
-    host=conn.host,
-    port=conn.port,
-    user=conn.login or None,
-    password=conn.password or None,
-    scheme=conn.schema or "http",
-)
+conn_args = {
+    "host": conn.host,
+    "port": conn.port,
+    "user": conn.login or None,
+    "password": conn.password or None,
+    "scheme": conn.schema or "http",
+}

if conn.extra_dejson.get("http_compress", False):
conn_args["http_compress"] = bool(["http_compress"])
12 changes: 6 additions & 6 deletions airflow/providers/exasol/hooks/exasol.py
@@ -56,12 +56,12 @@ def __init__(self, *args, **kwargs) -> None:
def get_conn(self) -> ExaConnection:
conn_id = getattr(self, self.conn_name_attr)
conn = self.get_connection(conn_id)
-conn_args = dict(
-    dsn=f"{conn.host}:{conn.port}",
-    user=conn.login,
-    password=conn.password,
-    schema=self.schema or conn.schema,
-)
+conn_args = {
+    "dsn": f"{conn.host}:{conn.port}",
+    "user": conn.login,
+    "password": conn.password,
+    "schema": self.schema or conn.schema,
+}
# check for parameters in conn.extra
for arg_name, arg_val in conn.extra_dejson.items():
if arg_name in ["compression", "encryption", "json_lib", "client_name"]:
4 changes: 3 additions & 1 deletion airflow/providers/google/ads/hooks/ads.py
@@ -224,7 +224,9 @@ def _search(

iterators = []
for client_id in client_ids:
-iterator = service.search(request=dict(customer_id=client_id, query=query, page_size=page_size))
+iterator = service.search(
+    request={"customer_id": client_id, "query": query, "page_size": page_size}
+)
iterators.append(iterator)

self.log.info("Fetched Google Ads Iterators")
@@ -100,19 +100,19 @@ def get_absolute_path(path):
return os.path.join(HOME_DIR, path)


-postgres_kwargs = dict(
-    user=quote_plus(GCSQL_POSTGRES_USER),
-    password=quote_plus(GCSQL_POSTGRES_PASSWORD),
-    public_port=GCSQL_POSTGRES_PUBLIC_PORT,
-    public_ip=quote_plus(GCSQL_POSTGRES_PUBLIC_IP),
-    project_id=quote_plus(GCP_PROJECT_ID),
-    location=quote_plus(GCP_REGION),
-    instance=quote_plus(GCSQL_POSTGRES_INSTANCE_NAME_QUERY),
-    database=quote_plus(GCSQL_POSTGRES_DATABASE_NAME),
-    client_cert_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_CERT_FILE)),
-    client_key_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_KEY_FILE)),
-    server_ca_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)),
-)
+postgres_kwargs = {
+    "user": quote_plus(GCSQL_POSTGRES_USER),
+    "password": quote_plus(GCSQL_POSTGRES_PASSWORD),
+    "public_port": GCSQL_POSTGRES_PUBLIC_PORT,
+    "public_ip": quote_plus(GCSQL_POSTGRES_PUBLIC_IP),
+    "project_id": quote_plus(GCP_PROJECT_ID),
+    "location": quote_plus(GCP_REGION),
+    "instance": quote_plus(GCSQL_POSTGRES_INSTANCE_NAME_QUERY),
+    "database": quote_plus(GCSQL_POSTGRES_DATABASE_NAME),
+    "client_cert_file": quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_CERT_FILE)),
+    "client_key_file": quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_KEY_FILE)),
+    "server_ca_file": quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)),
+}

# The connections below are created using one of the standard approaches - via environment
# variables named AIRFLOW_CONN_* . The connections can also be created in the database
@@ -166,19 +166,19 @@ def get_absolute_path(path):
"sslrootcert={server_ca_file}".format(**postgres_kwargs)
)

-mysql_kwargs = dict(
-    user=quote_plus(GCSQL_MYSQL_USER),
-    password=quote_plus(GCSQL_MYSQL_PASSWORD),
-    public_port=GCSQL_MYSQL_PUBLIC_PORT,
-    public_ip=quote_plus(GCSQL_MYSQL_PUBLIC_IP),
-    project_id=quote_plus(GCP_PROJECT_ID),
-    location=quote_plus(GCP_REGION),
-    instance=quote_plus(GCSQL_MYSQL_INSTANCE_NAME_QUERY),
-    database=quote_plus(GCSQL_MYSQL_DATABASE_NAME),
-    client_cert_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_CERT_FILE)),
-    client_key_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_KEY_FILE)),
-    server_ca_file=quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)),
-)
+mysql_kwargs = {
+    "user": quote_plus(GCSQL_MYSQL_USER),
+    "password": quote_plus(GCSQL_MYSQL_PASSWORD),
+    "public_port": GCSQL_MYSQL_PUBLIC_PORT,
+    "public_ip": quote_plus(GCSQL_MYSQL_PUBLIC_IP),
+    "project_id": quote_plus(GCP_PROJECT_ID),
+    "location": quote_plus(GCP_REGION),
+    "instance": quote_plus(GCSQL_MYSQL_INSTANCE_NAME_QUERY),
+    "database": quote_plus(GCSQL_MYSQL_DATABASE_NAME),
+    "client_cert_file": quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_CERT_FILE)),
+    "client_key_file": quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_KEY_FILE)),
+    "server_ca_file": quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)),
+}

# MySQL: connect via proxy over TCP (specific proxy version)
os.environ["AIRFLOW_CONN_PROXY_MYSQL_TCP"] = (
10 changes: 5 additions & 5 deletions airflow/providers/google/cloud/hooks/bigtable.py
@@ -148,11 +148,11 @@ def create_instance(
instance_labels,
)

-cluster_kwargs = dict(
-    cluster_id=main_cluster_id,
-    location_id=main_cluster_zone,
-    default_storage_type=cluster_storage_type,
-)
+cluster_kwargs = {
+    "cluster_id": main_cluster_id,
+    "location_id": main_cluster_zone,
+    "default_storage_type": cluster_storage_type,
+}
if instance_type != enums.Instance.Type.DEVELOPMENT and cluster_nodes:
cluster_kwargs["serve_nodes"] = cluster_nodes
clusters = [instance.cluster(**cluster_kwargs)]
@@ -517,7 +517,7 @@ async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
"""
client = self.get_conn()
jobs_list_request = ListTransferJobsRequest(
-filter=json.dumps(dict(project_id=self.project_id, job_names=job_names))
+filter=json.dumps({"project_id": self.project_id, "job_names": job_names})
)
return await client.list_transfer_jobs(request=jobs_list_request)

2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -314,7 +314,7 @@ def _authorize_compute_engine_instance_metadata(self, pubkey):
item["value"] = keys
break
else:
-new_dict = dict(key="ssh-keys", value=keys)
+new_dict = {"key": "ssh-keys", "value": keys}
metadata["items"] = [new_dict]

self._compute_hook.set_instance_metadata(
12 changes: 6 additions & 6 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -1236,12 +1236,12 @@ async def get_job(
client = await self.initialize_client(JobsV1Beta3AsyncClient)

request = GetJobRequest(
-dict(
-    project_id=project_id,
-    job_id=job_id,
-    view=job_view,
-    location=location,
-)
+{
+    "project_id": project_id,
+    "job_id": job_id,
+    "view": job_view,
+    "location": location,
+}
)

job = await client.get_job(