Use a single statement with multiple contexts instead of nested statements in providers (#33768)
hussein-awala committed Aug 26, 2023
1 parent 4bae275 commit 6d182be
Showing 11 changed files with 213 additions and 228 deletions.
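
The refactor is the same in every file touched by this commit: two or more nested with blocks that exist only to hold context managers are collapsed into a single with statement listing the managers separated by commas. A minimal, self-contained before/after sketch using only the standard library (illustrative only, not code from the repository):

from tempfile import NamedTemporaryFile, TemporaryDirectory

# Before: one context manager per `with`, each adding a level of nesting.
def run_nested(payload: str) -> None:
    with TemporaryDirectory(prefix="example_") as tmp_dir:
        with NamedTemporaryFile(dir=tmp_dir) as f:
            f.write(payload.encode("utf-8"))
            f.flush()

# After: one `with` listing both managers. They are entered left to right
# and exited in reverse order, so the behavior matches the nested form.
def run_flat(payload: str) -> None:
    with TemporaryDirectory(prefix="example_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
        f.write(payload.encode("utf-8"))
        f.flush()

run_nested("SELECT 1;")
run_flat("SELECT 1;")

The semantics do not change; the rewrite only removes an indentation level, which keeps the hook bodies flatter.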
124 changes: 61 additions & 63 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -236,58 +236,55 @@ def run_cli(
if schema:
hql = f"USE {schema};\n{hql}"

with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir) as f:
hql += "\n"
f.write(hql.encode("UTF-8"))
f.flush()
hive_cmd = self._prepare_cli_cmd()
env_context = get_context_from_env_var()
# Only extend the hive_conf if it is defined.
if hive_conf:
env_context.update(hive_conf)
hive_conf_params = self._prepare_hiveconf(env_context)
if self.mapred_queue:
hive_conf_params.extend(
[
"-hiveconf",
f"mapreduce.job.queuename={self.mapred_queue}",
"-hiveconf",
f"mapred.job.queue.name={self.mapred_queue}",
"-hiveconf",
f"tez.queue.name={self.mapred_queue}",
]
)
with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
hql += "\n"
f.write(hql.encode("UTF-8"))
f.flush()
hive_cmd = self._prepare_cli_cmd()
env_context = get_context_from_env_var()
# Only extend the hive_conf if it is defined.
if hive_conf:
env_context.update(hive_conf)
hive_conf_params = self._prepare_hiveconf(env_context)
if self.mapred_queue:
hive_conf_params.extend(
[
"-hiveconf",
f"mapreduce.job.queuename={self.mapred_queue}",
"-hiveconf",
f"mapred.job.queue.name={self.mapred_queue}",
"-hiveconf",
f"tez.queue.name={self.mapred_queue}",
]
)

if self.mapred_queue_priority:
hive_conf_params.extend(
["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"]
)
if self.mapred_queue_priority:
hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])

if self.mapred_job_name:
hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
if self.mapred_job_name:
hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])

hive_cmd.extend(hive_conf_params)
hive_cmd.extend(["-f", f.name])
hive_cmd.extend(hive_conf_params)
hive_cmd.extend(["-f", f.name])

if verbose:
self.log.info("%s", " ".join(hive_cmd))
sub_process: Any = subprocess.Popen(
hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
stdout = ""
for line in iter(sub_process.stdout.readline, b""):
line = line.decode()
stdout += line
if verbose:
self.log.info("%s", " ".join(hive_cmd))
sub_process: Any = subprocess.Popen(
hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
stdout = ""
for line in iter(sub_process.stdout.readline, b""):
line = line.decode()
stdout += line
if verbose:
self.log.info(line.strip())
sub_process.wait()
self.log.info(line.strip())
sub_process.wait()

if sub_process.returncode:
raise AirflowException(stdout)
if sub_process.returncode:
raise AirflowException(stdout)

return stdout
return stdout

def test_hql(self, hql: str) -> None:
"""Test an hql statement using the hive cli and EXPLAIN."""
@@ -376,25 +373,26 @@ def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
if pandas_kwargs is None:
pandas_kwargs = {}

with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
if field_dict is None:
field_dict = _infer_field_types_from_df(df)

df.to_csv(
path_or_buf=f,
sep=delimiter,
header=False,
index=False,
encoding=encoding,
date_format="%Y-%m-%d %H:%M:%S",
**pandas_kwargs,
)
f.flush()
with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(
dir=tmp_dir, mode="w"
) as f:
if field_dict is None:
field_dict = _infer_field_types_from_df(df)

df.to_csv(
path_or_buf=f,
sep=delimiter,
header=False,
index=False,
encoding=encoding,
date_format="%Y-%m-%d %H:%M:%S",
**pandas_kwargs,
)
f.flush()

return self.load_file(
filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
)
return self.load_file(
filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
)

def load_file(
self,
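When the combined statement gets long, as in the DataFrame-to-CSV block above, the new code wraps it at the NamedTemporaryFile(...) call. As an aside (not part of this commit): on Python 3.10 and later the same statement can instead be wrapped with parentheses around the list of context managers, which reads more evenly; it is only an option for projects that have dropped older interpreters. A sketch of that form:

from tempfile import NamedTemporaryFile, TemporaryDirectory

# Python 3.10+ only: parentheses let a multi-context `with` span several lines.
with (
    TemporaryDirectory(prefix="example_") as tmp_dir,
    NamedTemporaryFile(dir=tmp_dir, mode="w") as f,
):
    f.write("col_a,col_b\n")
    f.flush()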
29 changes: 14 additions & 15 deletions airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -136,21 +136,20 @@ def execute(self, context: Context):
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
self.log.info("Dumping MySQL query results to local file")
with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
with closing(mysql.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
csv_writer = csv.writer(
f,
delimiter=self.delimiter,
quoting=self.quoting,
quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
escapechar=self.escapechar,
)
field_dict = {}
if cursor.description is not None:
for field in cursor.description:
field_dict[field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor) # type: ignore[arg-type]
with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
csv_writer = csv.writer(
f,
delimiter=self.delimiter,
quoting=self.quoting,
quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
escapechar=self.escapechar,
)
field_dict = {}
if cursor.description is not None:
for field in cursor.description:
field_dict[field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor) # type: ignore[arg-type]
f.flush()
self.log.info("Loading file into Hive")
hive.load_file(
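The database-facing changes (mysql_to_hive above, exasol below) use the same comma-separated form with contextlib.closing, which wraps any object exposing close() in a context manager. The second manager expression may refer to the name bound by the first, because the managers are evaluated and entered left to right. A minimal sketch using the standard-library sqlite3 driver (illustrative; the real hooks use the MySQL and Exasol connections):

import sqlite3
from contextlib import closing

# On exit the cursor is closed first, then the connection.
with closing(sqlite3.connect(":memory:")) as conn, closing(conn.cursor()) as cursor:
    cursor.execute("SELECT 1")
    print(cursor.fetchall())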
65 changes: 32 additions & 33 deletions airflow/providers/apache/pig/hooks/pig.py
@@ -64,41 +64,40 @@ def run_cli(self, pig: str, pig_opts: str | None = None, verbose: bool = True) -
>>> ("hdfs://" in result)
True
"""
with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir) as f:
f.write(pig.encode("utf-8"))
f.flush()
fname = f.name
pig_bin = "pig"
cmd_extra: list[str] = []

pig_cmd = [pig_bin]

if self.pig_properties:
pig_cmd.extend(self.pig_properties)
if pig_opts:
pig_opts_list = pig_opts.split()
pig_cmd.extend(pig_opts_list)
with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
f.write(pig.encode("utf-8"))
f.flush()
fname = f.name
pig_bin = "pig"
cmd_extra: list[str] = []

pig_cmd = [pig_bin]

if self.pig_properties:
pig_cmd.extend(self.pig_properties)
if pig_opts:
pig_opts_list = pig_opts.split()
pig_cmd.extend(pig_opts_list)

pig_cmd.extend(["-f", fname] + cmd_extra)

if verbose:
self.log.info("%s", " ".join(pig_cmd))
sub_process: Any = subprocess.Popen(
pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
stdout = ""
for line in iter(sub_process.stdout.readline, b""):
stdout += line.decode("utf-8")
if verbose:
self.log.info(line.strip())
sub_process.wait()

pig_cmd.extend(["-f", fname] + cmd_extra)
if sub_process.returncode:
raise AirflowException(stdout)

if verbose:
self.log.info("%s", " ".join(pig_cmd))
sub_process: Any = subprocess.Popen(
pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
stdout = ""
for line in iter(sub_process.stdout.readline, b""):
stdout += line.decode("utf-8")
if verbose:
self.log.info(line.strip())
sub_process.wait()

if sub_process.returncode:
raise AirflowException(stdout)

return stdout
return stdout

def kill(self) -> None:
"""Kill Pig job."""
15 changes: 8 additions & 7 deletions airflow/providers/dbt/cloud/hooks/dbt.py
@@ -234,13 +234,14 @@ async def get_job_details(
endpoint = f"{account_id}/runs/{run_id}/"
headers, tenant = await self.get_headers_tenants_from_connection()
url, params = self.get_request_url_params(tenant, endpoint, include_related)
async with aiohttp.ClientSession(headers=headers) as session:
async with session.get(url, params=params) as response:
try:
response.raise_for_status()
return await response.json()
except ClientResponseError as e:
raise AirflowException(str(e.status) + ":" + e.message)
async with aiohttp.ClientSession(headers=headers) as session, session.get(
url, params=params
) as response:
try:
response.raise_for_status()
return await response.json()
except ClientResponseError as e:
raise AirflowException(f"{e.status}:{e.message}")

async def get_job_status(
self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
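async with accepts the same comma-separated form, which is what the dbt hook change relies on: session.get(...) can use the session opened by the first manager because async context managers are also entered left to right. A minimal sketch, assuming aiohttp is installed and the URL is a placeholder:

import aiohttp


async def fetch_json(url: str) -> dict:
    # On exit the response is released first, then the session is closed.
    async with aiohttp.ClientSession() as session, session.get(url) as response:
        response.raise_for_status()
        return await response.json()


# To run against any JSON-returning endpoint:
# import asyncio; asyncio.run(fetch_json("https://example.com/api/status"))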
10 changes: 4 additions & 6 deletions airflow/providers/exasol/hooks/exasol.py
@@ -97,9 +97,8 @@ def get_records(
sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
with closing(self.get_conn()) as conn:
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()
with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""Execute the SQL and return the first resulting row.
@@ -108,9 +107,8 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An
sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
with closing(self.get_conn()) as conn:
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchone()
with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
return cur.fetchone()

def export_to_file(
self,
7 changes: 3 additions & 4 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -550,10 +550,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
if gzip:
filename_gz = filename + ".gz"

with open(filename, "rb") as f_in:
with gz.open(filename_gz, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
filename = filename_gz
with open(filename, "rb") as f_in, gz.open(filename_gz, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
filename = filename_gz

_call_with_retry(
partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
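The GCS hook change applies the same idea to plain file handles: the source file and the gzip target are opened in one statement, and both are closed even if the copy raises. A small self-contained sketch of that compression step (file names are made up):

import gzip
import shutil


def gzip_file(filename: str) -> str:
    """Compress ``filename`` to ``filename + '.gz'`` and return the new path."""
    filename_gz = filename + ".gz"
    with open(filename, "rb") as f_in, gzip.open(filename_gz, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    return filename_gz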
58 changes: 28 additions & 30 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -493,19 +493,18 @@ async def delete_pod(self, name: str, namespace: str):
:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with Token(scopes=self.scopes) as token:
async with self.get_conn(token) as connection:
try:
v1_api = async_client.CoreV1Api(connection)
await v1_api.delete_namespaced_pod(
name=name,
namespace=namespace,
body=client.V1DeleteOptions(),
)
except async_client.ApiException as e:
# If the pod is already deleted
if e.status != 404:
raise
async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
try:
v1_api = async_client.CoreV1Api(connection)
await v1_api.delete_namespaced_pod(
name=name,
namespace=namespace,
body=client.V1DeleteOptions(),
)
except async_client.ApiException as e:
# If the pod is already deleted
if e.status != 404:
raise

async def read_logs(self, name: str, namespace: str):
"""Read logs inside the pod while starting containers inside.
@@ -518,20 +517,19 @@ async def read_logs(self, name: str, namespace: str):
:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with Token(scopes=self.scopes) as token:
async with self.get_conn(token) as connection:
try:
v1_api = async_client.CoreV1Api(connection)
logs = await v1_api.read_namespaced_pod_log(
name=name,
namespace=namespace,
follow=False,
timestamps=True,
)
logs = logs.splitlines()
for line in logs:
self.log.info("Container logs from %s", line)
return logs
except HTTPError:
self.log.exception("There was an error reading the kubernetes API.")
raise
async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
try:
v1_api = async_client.CoreV1Api(connection)
logs = await v1_api.read_namespaced_pod_log(
name=name,
namespace=namespace,
follow=False,
timestamps=True,
)
logs = logs.splitlines()
for line in logs:
self.log.info("Container logs from %s", line)
return logs
except HTTPError:
self.log.exception("There was an error reading the kubernetes API.")
raise
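
One related alternative, not needed by this commit because every combined pair here is unconditional: when some context managers are optional or only known at runtime, contextlib.AsyncExitStack (or ExitStack for sync code) keeps the block flat without a long with header. A hedged sketch with made-up resource names:

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def resource(name: str):
    print(f"open {name}")
    try:
        yield name
    finally:
        print(f"close {name}")


async def main(include_connection: bool) -> None:
    async with AsyncExitStack() as stack:
        token = await stack.enter_async_context(resource("token"))
        if include_connection:
            await stack.enter_async_context(resource("connection"))
        print("working with", token)


asyncio.run(main(include_connection=True))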
