Remove redundant None provided as default to dict.get() (#11448)
kaxil committed Oct 11, 2020
1 parent d8d13fa commit d305876
Showing 26 changed files with 45 additions and 45 deletions.
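
For context, Python's dict.get(key) already returns None when the key is absent, so passing None as the second argument changes nothing; an explicit default is only useful for a non-None fallback. A minimal sketch of the behaviour (the config dict below is a made-up example, not taken from the Airflow code):

    # dict.get() defaults to None for missing keys
    config = {"queue": "default"}

    assert config.get("deploy-mode") is None                 # implicit default
    assert config.get("deploy-mode", None) is None           # identical result; the None is redundant
    assert config.get("deploy-mode", "client") == "client"   # only non-None fallbacks need the argument
    assert config.get("queue") == "default"                  # existing keys are unaffected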
2 changes: 1 addition & 1 deletion airflow/api/common/experimental/trigger_dag.py
@@ -89,7 +89,7 @@ def _trigger_dag(
state=State.RUNNING,
conf=run_conf,
external_trigger=True,
-dag_hash=dag_bag.dags_hash.get(dag_id, None),
+dag_hash=dag_bag.dags_hash.get(dag_id),
)

triggers.append(trigger)
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/connection_endpoint.py
@@ -97,7 +97,7 @@ def patch_connection(connection_id, session, update_mask=None):
"Connection not found",
detail=f"The Connection with connection_id: `{connection_id}` was not found",
)
-if data.get('conn_id', None) and connection.conn_id != data['conn_id']:
+if data.get('conn_id') and connection.conn_id != data['conn_id']:
raise BadRequest(detail="The connection_id cannot be updated.")
if update_mask:
update_mask = [i.strip() for i in update_mask]
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/log_endpoint.py
@@ -42,7 +42,7 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=
except BadSignature:
raise BadRequest("Bad Signature. Please use only the tokens provided by the API.")

-if metadata.get('download_logs', None) and metadata['download_logs']:
+if metadata.get('download_logs') and metadata['download_logs']:
full_content = True

if full_content:
2 changes: 1 addition & 1 deletion airflow/api_connexion/exceptions.py
@@ -42,7 +42,7 @@ def common_error_handler(exception):
"""
if isinstance(exception, ProblemException):

-link = EXCEPTIONS_LINK_MAP.get(exception.status, None)
+link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
2 changes: 1 addition & 1 deletion airflow/api_connexion/security.py
@@ -81,7 +81,7 @@ def requires_access_decorator(func: T):
def decorated(*args, **kwargs):

check_authentication()
-check_authorization(permissions, kwargs.get('dag_id', None))
+check_authorization(permissions, kwargs.get('dag_id'))

return func(*args, **kwargs)

2 changes: 1 addition & 1 deletion airflow/jobs/scheduler_job.py
@@ -1555,7 +1555,7 @@ def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) ->
"""
for dag_model in dag_models:
dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
-dag_hash = self.dagbag.dags_hash.get(dag.dag_id, None)
+dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=dag_model.next_dagrun,
2 changes: 1 addition & 1 deletion airflow/operators/python.py
@@ -187,7 +187,7 @@ def __init__(
multiple_outputs: bool = False,
**kwargs
) -> None:
-kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag', None))
+kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'))
super().__init__(**kwargs)
self.python_callable = python_callable

2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/step_function.py
@@ -67,7 +67,7 @@ def start_execution(
self.log.info('Executing Step Function State Machine: %s', state_machine_arn)

response = self.conn.start_execution(**execution_args)
-return response.get('executionArn', None)
+return response.get('executionArn')

def describe_execution(self, execution_arn: str) -> dict:
"""
6 changes: 3 additions & 3 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -209,9 +209,9 @@ def _resolve_connection(self) -> Dict[str, Any]:

# Determine optional yarn queue from the extra field
extra = conn.extra_dejson
-conn_data['queue'] = extra.get('queue', None)
-conn_data['deploy_mode'] = extra.get('deploy-mode', None)
-conn_data['spark_home'] = extra.get('spark-home', None)
+conn_data['queue'] = extra.get('queue')
+conn_data['deploy_mode'] = extra.get('deploy-mode')
+conn_data['spark_home'] = extra.get('spark-home')
conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit")
conn_data['namespace'] = extra.get('namespace')
except AirflowException:
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/functions.py
@@ -480,6 +480,6 @@ def execute(self, context: Dict):
location=self.location,
project_id=self.project_id,
)
-self.log.info('Function called successfully. Execution id %s', result.get('executionId', None))
-self.xcom_push(context=context, key='execution_id', value=result.get('executionId', None))
+self.log.info('Function called successfully. Execution id %s', result.get('executionId'))
+self.xcom_push(context=context, key='execution_id', value=result.get('executionId'))
return result
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/mlengine.py
@@ -283,7 +283,7 @@ def execute(self, context):
# Helper method to check if the existing job's prediction input is the
# same as the request we get here.
def check_existing_job(existing_job):
-return existing_job.get('predictionInput', None) == prediction_request['predictionInput']
+return existing_job.get('predictionInput') == prediction_request['predictionInput']

finished_prediction_job = hook.create_job(
project_id=self._project_id, job=prediction_request, use_existing_job_fn=check_existing_job
@@ -1264,12 +1264,12 @@ def execute(self, context):
# Helper method to check if the existing job's training input is the
# same as the request we get here.
def check_existing_job(existing_job):
-existing_training_input = existing_job.get('trainingInput', None)
+existing_training_input = existing_job.get('trainingInput')
requested_training_input = training_request['trainingInput']
if 'scaleTier' not in existing_training_input:
existing_training_input['scaleTier'] = None

-existing_training_input['args'] = existing_training_input.get('args', None)
+existing_training_input['args'] = existing_training_input.get('args')
requested_training_input["args"] = (
requested_training_input['args'] if requested_training_input["args"] else None
)
8 changes: 4 additions & 4 deletions airflow/providers/oracle/hooks/oracle.py
@@ -57,11 +57,11 @@ def get_conn(self) -> 'OracleHook':
self.oracle_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
conn_config = {'user': conn.login, 'password': conn.password}
-dsn = conn.extra_dejson.get('dsn', None)
-sid = conn.extra_dejson.get('sid', None)
-mod = conn.extra_dejson.get('module', None)
+dsn = conn.extra_dejson.get('dsn')
+sid = conn.extra_dejson.get('sid')
+mod = conn.extra_dejson.get('module')

-service_name = conn.extra_dejson.get('service_name', None)
+service_name = conn.extra_dejson.get('service_name')
port = conn.port if conn.port else 1521
if dsn and sid and not service_name:
conn_config['dsn'] = cx_Oracle.makedsn(dsn, port, sid)
2 changes: 1 addition & 1 deletion airflow/providers/postgres/hooks/postgres.py
@@ -223,7 +223,7 @@ def _generate_insert_sql(
placeholders = [
"%s",
] * len(values)
replace_index = kwargs.get("replace_index", None)
replace_index = kwargs.get("replace_index")

if target_fields:
target_fields_fragment = ", ".join(target_fields)
2 changes: 1 addition & 1 deletion airflow/providers/redis/hooks/redis.py
@@ -56,7 +56,7 @@ def get_conn(self):
self.host = conn.host
self.port = conn.port
self.password = None if str(conn.password).lower() in ['none', 'false', ''] else conn.password
-self.db = conn.extra_dejson.get('db', None)
+self.db = conn.extra_dejson.get('db')

# check for ssl parameters in conn.extra
ssl_arg_names = [
2 changes: 1 addition & 1 deletion airflow/providers/salesforce/hooks/salesforce.py
@@ -72,7 +72,7 @@ def get_conn(self) -> api.Salesforce:
password=connection.password,
security_token=extras['security_token'],
instance_url=connection.host,
-domain=extras.get('domain', None),
+domain=extras.get('domain'),
)
return self.conn

2 changes: 1 addition & 1 deletion airflow/providers/sendgrid/utils/emailer.py
@@ -89,7 +89,7 @@ def send_email(
personalization.add_bcc(Email(bcc_address))

# Add custom_args to personalization if present
-pers_custom_args = kwargs.get('personalization_custom_args', None)
+pers_custom_args = kwargs.get('personalization_custom_args')
if isinstance(pers_custom_args, dict):
for key in pers_custom_args.keys():
personalization.add_custom_arg(CustomArg(key, pers_custom_args[key]))
2 changes: 1 addition & 1 deletion airflow/providers/sftp/hooks/sftp.py
@@ -62,7 +62,7 @@ def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
if conn.extra is not None:
extra_options = conn.extra_dejson
if 'private_key_pass' in extra_options:
-self.private_key_pass = extra_options.get('private_key_pass', None)
+self.private_key_pass = extra_options.get('private_key_pass')

# For backward compatibility
# TODO: remove in Airflow 2.1
2 changes: 1 addition & 1 deletion airflow/providers/snowflake/hooks/snowflake.py
@@ -80,7 +80,7 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
# passphrase for the private key. If your private key file is not encrypted (not recommended), then
# leave the password empty.

-private_key_file = conn.extra_dejson.get('private_key_file', None)
+private_key_file = conn.extra_dejson.get('private_key_file')
if private_key_file:
with open(private_key_file, "rb") as key:
passphrase = None
12 changes: 6 additions & 6 deletions airflow/sensors/smart_sensor_operator.py
@@ -499,18 +499,18 @@ def email_alert(task_instance, error_info):
sensor_work.log.exception(e, exc_info=True)

def handle_failure(sensor_work, ti):
-if sensor_work.execution_context.get('retries', None) and \
+if sensor_work.execution_context.get('retries') and \
ti.try_number <= ti.max_tries:
# retry
ti.state = State.UP_FOR_RETRY
-if sensor_work.execution_context.get('email_on_retry', None) and \
-sensor_work.execution_context.get('email', None):
+if sensor_work.execution_context.get('email_on_retry') and \
+sensor_work.execution_context.get('email'):
sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key)
email_alert(ti, error)
else:
ti.state = State.FAILED
-if sensor_work.execution_context.get('email_on_failure', None) and \
-sensor_work.execution_context.get('email', None):
+if sensor_work.execution_context.get('email_on_failure') and \
+sensor_work.execution_context.get('email'):
sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key)
email_alert(ti, error)

@@ -566,7 +566,7 @@ def _check_and_handle_ti_timeout(self, sensor_work):
:param sensor_work: SensorWork
"""
task_timeout = sensor_work.execution_context.get('timeout', self.timeout)
-task_execution_timeout = sensor_work.execution_context.get('execution_timeout', None)
+task_execution_timeout = sensor_work.execution_context.get('execution_timeout')
if task_execution_timeout:
task_timeout = min(task_timeout, task_execution_timeout)

2 changes: 1 addition & 1 deletion airflow/utils/decorators.py
@@ -60,7 +60,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
dag_args: Dict[str, Any] = {}
dag_params: Dict[str, Any] = {}

-dag = kwargs.get('dag', None) or DagContext.get_current_dag()
+dag = kwargs.get('dag') or DagContext.get_current_dag()
if dag:
dag_args = copy(dag.default_args) or {}
dag_params = copy(dag.params) or {}
8 changes: 4 additions & 4 deletions airflow/www/views.py
@@ -410,9 +410,9 @@ def get_int_arg(value, default=0):
return default

arg_current_page = request.args.get('page', '0')
-arg_search_query = request.args.get('search', None)
-arg_tags_filter = request.args.getlist('tags', None)
-arg_status_filter = request.args.get('status', None)
+arg_search_query = request.args.get('search')
+arg_tags_filter = request.args.getlist('tags')
+arg_status_filter = request.args.get('status')

if request.args.get('reset_tags') is not None:
flask_session[FILTER_TAGS_COOKIE] = None
@@ -1282,7 +1282,7 @@ def trigger(self, session=None):
state=State.RUNNING,
conf=run_conf,
external_trigger=True,
-dag_hash=current_app.dag_bag.dags_hash.get(dag_id, None),
+dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)

flash(
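
One nuance in the airflow/www/views.py change above: request.args is a Werkzeug MultiDict, and MultiDict.getlist() takes a type-coercion callable as its second argument rather than a default value, returning an empty list for missing keys. Passing None there was a no-op as well, so dropping it is equally safe. A minimal sketch, assuming Werkzeug is installed:

    from werkzeug.datastructures import MultiDict

    args = MultiDict([("status", "active")])

    assert args.getlist("tags") == []                 # missing key -> empty list, never None
    assert args.getlist("tags", None) == []           # second argument is a type callable; None is its default anyway
    assert args.getlist("status", str) == ["active"]  # the callable is applied to each value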
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -79,7 +79,7 @@ def test_create_bucket_us_standard_region(self, monkeypatch):
hook.create_bucket(bucket_name='new_bucket', region_name='us-east-1')
bucket = hook.get_bucket('new_bucket')
assert bucket is not None
-region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint', None)
+region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint')
# https://github.com/spulec/moto/pull/1961
# If location is "us-east-1", LocationConstraint should be None
assert region is None
@@ -90,7 +90,7 @@ def test_create_bucket_other_region(self):
hook.create_bucket(bucket_name='new_bucket', region_name='us-east-2')
bucket = hook.get_bucket('new_bucket')
assert bucket is not None
-region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint', None)
+region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint')
assert region == 'us-east-2'

def test_check_for_prefix(self, s3_bucket):
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/hooks/test_step_function.py
@@ -41,7 +41,7 @@ def test_start_execution(self):
name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role'
)

-state_machine_arn = state_machine.get('stateMachineArn', None)
+state_machine_arn = state_machine.get('stateMachineArn')

execution_arn = hook.start_execution(
state_machine_arn=state_machine_arn, name=None, state_machine_input={}
@@ -56,7 +56,7 @@ def test_describe_execution(self):
name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role'
)

-state_machine_arn = state_machine.get('stateMachineArn', None)
+state_machine_arn = state_machine.get('stateMachineArn')

execution_arn = hook.start_execution(
state_machine_arn=state_machine_arn, name=None, state_machine_input={}
4 changes: 2 additions & 2 deletions tests/providers/google/cloud/hooks/test_mlengine.py
@@ -714,7 +714,7 @@ def test_create_mlengine_job_check_existing_job_failed(self, mock_get_conn):

# fmt: on
def check_input(existing_job):
-return existing_job.get('someInput', None) == my_job['someInput']
+return existing_job.get('someInput') == my_job['someInput']

with self.assertRaises(HttpError):
self.hook.create_job(project_id=project_id, job=my_job, use_existing_job_fn=check_input)
@@ -748,7 +748,7 @@ def test_create_mlengine_job_check_existing_job_success(self, mock_get_conn):

# fmt: on
def check_input(existing_job):
-return existing_job.get('someInput', None) == my_job['someInput']
+return existing_job.get('someInput') == my_job['someInput']

create_job_response = self.hook.create_job(
project_id=project_id, job=my_job, use_existing_job_fn=check_input
2 changes: 1 addition & 1 deletion tests/providers/salesforce/hooks/test_salesforce.py
@@ -55,7 +55,7 @@ def test_get_conn(self, mock_salesforce, mock_get_connection):
password=mock_get_connection.return_value.password,
security_token=mock_get_connection.return_value.extra_dejson["security_token"],
instance_url=mock_get_connection.return_value.host,
-domain=mock_get_connection.return_value.extra_dejson.get("domain", None),
+domain=mock_get_connection.return_value.extra_dejson.get("domain"),
)

@patch("airflow.providers.salesforce.hooks.salesforce.Salesforce")
2 changes: 1 addition & 1 deletion tests/test_utils/mock_plugins.py
@@ -62,7 +62,7 @@ def mock_plugin_manager(**kwargs):
with ExitStack() as exit_stack:
for attr in PLUGINS_MANAGER_NULLABLE_ATTRIBUTES:
exit_stack.enter_context( # pylint: disable=no-member
mock.patch(f"airflow.plugins_manager.{attr}", kwargs.get(attr, None))
mock.patch(f"airflow.plugins_manager.{attr}", kwargs.get(attr))
)
exit_stack.enter_context( # pylint: disable=no-member
mock.patch("airflow.plugins_manager.import_errors", kwargs.get("import_errors", {}))
