
Default split_statements to True for Presto and Trino Hook #47186

Open · wants to merge 8 commits into main

Conversation

willshen99

closes #47167


boring-cyborg bot commented Feb 28, 2025

Congratulations on your first Pull Request and welcome to the Apache Airflow community! If you have any issues or are unsure about anything, please check our Contributors' Guide (https://github.com/apache/airflow/blob/main/contributing-docs/README.rst)
Here are some useful points:

  • Pay attention to the quality of your code (ruff, mypy and type annotations). Our pre-commits will help you with that.
  • In case of a new feature, add useful documentation (in docstrings or in the docs/ directory). Adding a new operator? Check this short guide. Consider adding an example DAG that shows how users should use it.
  • Consider using the Breeze environment for testing locally; it's a heavy Docker setup, but it ships with a working Airflow and a lot of integrations.
  • Be patient and persistent. It might take some time to get a review or the final approval from Committers.
  • Please follow the ASF Code of Conduct for all communication, including (but not limited to) comments on Pull Requests, the mailing list and Slack.
  • Be sure to read the Airflow coding style guide.
  • Always keep your Pull Requests rebased; otherwise your build might fail due to changes not related to your commits.
    Apache Airflow is a community-driven project and together we are making it better 🚀.
    In case of doubts contact the developers at:
    Mailing List: [email protected]
    Slack: https://s.apache.org/airflow-slack

Member

@jason810496 jason810496 left a comment


Looks good! Would you mind adding a test for this change? Thanks!

@willshen99
Author

@jason810496 Thank you for the review! I added a test to check the default split_statements in the run method. Not sure if any other tests are needed?

Contributor

@bugraoz93 bugraoz93 left a comment


Thanks for the changes! We should include a test case per overload, because there are major differences; for example, one takes a Callable for handler while the other takes None. Something like initialising the hook with these and testing the run method should work. I think we should follow this for both hooks.
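The per-overload testing idea above could be sketched roughly as follows. This is a hypothetical illustration, not the actual provider test file: FakeDbApiHook and FakePrestoHook are local stand-ins that only model how split_statements and handler flow through run().

```python
# Hypothetical sketch: local stand-ins illustrating one test per run() overload
# (handler=None vs. handler=Callable), with split_statements defaulting to True.


class FakeDbApiHook:
    """Stand-in for the base hook: records what run() received."""

    def run(self, sql, autocommit=False, parameters=None, handler=None, split_statements=False):
        return {"handler": handler, "split_statements": split_statements}


class FakePrestoHook(FakeDbApiHook):
    """Stand-in for PrestoHook after this change: split_statements defaults to True."""

    def run(self, sql, autocommit=False, parameters=None, handler=None, split_statements=True):
        return super().run(sql, autocommit, parameters, handler, split_statements)


def test_run_default_split_without_handler():
    # Overload where handler is None.
    result = FakePrestoHook().run("SELECT 1")
    assert result["handler"] is None
    assert result["split_statements"] is True


def test_run_default_split_with_handler():
    # Overload where handler is a Callable.
    handler = lambda cursor: cursor.fetchall()
    result = FakePrestoHook().run("SELECT 1", handler=handler)
    assert result["handler"] is handler
    assert result["split_statements"] is True
```

The same shape would apply to both the Presto and Trino hook test modules.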

@willshen99
Author

Thank you @bugraoz93 for the direction! This makes total sense! I'm new to contributing to Airflow -- is there any provider example I can follow to add tests?

@bugraoz93
Contributor

Thank you @bugraoz93 for the direction! This makes total sense! I'm new to contributing to Airflow -- is there any provider example I can follow to add tests?

The tests depend on the case; in this case, I believe it should be under providers/presto/tests/unit/presto/hooks/test_presto.py.

If you would like specific examples, please take a look at them here. I believe there are examples where filename is None or filled, which corresponds to an overload in this context. So in your case, those overloads should be covered in tests according to their behaviour in the code.
Code

@overload
def download(
    self,
    bucket_name: str,
    object_name: str,
    filename: None = None,
    chunk_size: int | None = None,
    timeout: int | None = DEFAULT_TIMEOUT,
    num_max_attempts: int | None = 1,
    user_project: str | None = None,
) -> bytes: ...

@overload
def download(
    self,
    bucket_name: str,
    object_name: str,
    filename: str,
    chunk_size: int | None = None,
    timeout: int | None = DEFAULT_TIMEOUT,
    num_max_attempts: int | None = 1,
    user_project: str | None = None,
) -> str: ...

def download(
    self,
    bucket_name: str,
    object_name: str,
    filename: str | None = None,
    chunk_size: int | None = None,
    timeout: int | None = DEFAULT_TIMEOUT,
    num_max_attempts: int | None = 1,
    user_project: str | None = None,
) -> str | bytes:
    """
    Download a file from Google Cloud Storage.

    When no filename is supplied, the operator loads the file into memory and returns its
    content. When a filename is supplied, it writes the file to the specified location and
    returns the location. For file sizes that exceed the available memory it is recommended
    to write to a file.

    :param bucket_name: The bucket to fetch from.
    :param object_name: The object to fetch.
    :param filename: If set, a local file path where the file should be written to.
    :param chunk_size: Blob chunk size.
    :param timeout: Request timeout in seconds.
    :param num_max_attempts: Number of attempts to download the file.
    :param user_project: The identifier of the Google Cloud project to bill for the request.
        Required for Requester Pays buckets.
    """
    # TODO: future improvement check file size before downloading,
    # to check for local space availability
    if num_max_attempts is None:
        num_max_attempts = 3

    for attempt in range(num_max_attempts):
        if attempt:
            # Wait with exponential backoff scheme before retrying.
            timeout_seconds = 2**attempt
            time.sleep(timeout_seconds)
        try:
            client = self.get_conn()
            bucket = client.bucket(bucket_name, user_project=user_project)
            blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

            if filename:
                blob.download_to_filename(filename, timeout=timeout)
                get_hook_lineage_collector().add_input_asset(
                    context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
                )
                get_hook_lineage_collector().add_output_asset(
                    context=self, scheme="file", asset_kwargs={"path": filename}
                )
                self.log.info("File downloaded to %s", filename)
                return filename
            else:
                get_hook_lineage_collector().add_input_asset(
                    context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
                )
                return blob.download_as_bytes()
        except GoogleCloudError:
            if attempt == num_max_attempts - 1:
                self.log.error(
                    "Download attempt of object: %s from %s has failed. Attempt: %s, max %s.",
                    object_name,
                    bucket_name,
                    attempt,
                    num_max_attempts,
                )
                raise
    else:
        raise NotImplementedError  # should not reach this, but makes mypy happy

def download_as_byte_array(
    self,
    bucket_name: str,
    object_name: str,
    chunk_size: int | None = None,
    timeout: int | None = DEFAULT_TIMEOUT,
    num_max_attempts: int | None = 1,
) -> bytes:
    """
    Download a file from Google Cloud Storage.

    When no filename is supplied, the operator loads the file into memory and returns its
    content. When a filename is supplied, it writes the file to the specified location and
    returns the location. For file sizes that exceed the available memory it is recommended
    to write to a file.

    :param bucket_name: The bucket to fetch from.
    :param object_name: The object to fetch.
    :param chunk_size: Blob chunk size.
    :param timeout: Request timeout in seconds.
    :param num_max_attempts: Number of attempts to download the file.
    """
    # We do not pass filename, so will never receive string as response
    return self.download(
        bucket_name=bucket_name,
        object_name=object_name,
        chunk_size=chunk_size,
        timeout=timeout,
        num_max_attempts=num_max_attempts,
    )

Tests
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_download_as_bytes(self, mock_service):
    test_bucket = "test_bucket"
    test_object = "test_object"
    test_object_bytes = BytesIO(b"input")
    download_method = mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
    download_method.return_value = test_object_bytes
    response = self.gcs_hook.download(bucket_name=test_bucket, object_name=test_object, filename=None)
    assert response == test_object_bytes
    download_method.assert_called_once_with()

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
@mock.patch("google.cloud.storage.Blob.download_as_bytes")
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_download_as_bytes_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector):
    source_bucket_name = "test-source-bucket"
    source_object_name = "test-source-object"
    mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
    self.gcs_hook.download(bucket_name=source_bucket_name, object_name=source_object_name, filename=None)
    assert len(hook_lineage_collector.collected_assets.inputs) == 1
    assert len(hook_lineage_collector.collected_assets.outputs) == 0
    assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset(
        uri=f"gs://{source_bucket_name}/{source_object_name}"
    )

@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_download_to_file(self, mock_service):
    test_bucket = "test_bucket"
    test_object = "test_object"
    test_object_bytes = BytesIO(b"input")
    test_file = "test_file"
    download_filename_method = (
        mock_service.return_value.bucket.return_value.blob.return_value.download_to_filename
    )
    download_filename_method.return_value = None
    download_as_a_bytes_method = (
        mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
    )
    download_as_a_bytes_method.return_value = test_object_bytes
    response = self.gcs_hook.download(
        bucket_name=test_bucket, object_name=test_object, filename=test_file
    )
    assert response == test_file
    download_filename_method.assert_called_once_with(test_file, timeout=60)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
@mock.patch("google.cloud.storage.Blob.download_to_filename")
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_download_to_file_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector):
    source_bucket_name = "test-source-bucket"
    source_object_name = "test-source-object"
    file_name = "test.txt"
    mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
    self.gcs_hook.download(
        bucket_name=source_bucket_name, object_name=source_object_name, filename=file_name
    )
    assert len(hook_lineage_collector.collected_assets.inputs) == 1
    assert len(hook_lineage_collector.collected_assets.outputs) == 1
    assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset(
        uri=f"gs://{source_bucket_name}/{source_object_name}"
    )
    assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset(uri=f"file://{file_name}")

@HowlyBlood

HowlyBlood commented Mar 6, 2025

Hi, this is my first "contribution", even if it is just a suggestion. I'm clearly new to this, but I was wondering if some users have statements like

CREATE TABLE example (id INT); -- added ; in a comment. If the split is not "comment zone safe", it may raise an error
INSERT INTO example (id) VALUES (1);
SELECT * FROM 

or "BEGIN; ... COMMIT;", which makes no sense if executed as multiple statements.
As split_statements was initially False, it was behaving correctly, but now it may not behave as expected.

I'm not sure if it's a good solution, but providing split_statements = hasattr(sql, '__iter__') would prevent any unexpected behaviour: if it's not an iterable, it would not split. No change from before the update.

I'm still looking for the definition of split_statements, but I think it splits strings on semicolons and iterables by their elements, so it might be an intuitive way to code: you use an iterable if you want it to be split; otherwise, you use a single string (for a single statement).

["DESCRIBE ... ;","SELECT ...;","BEGIN; ... ; COMMIT;"] --> split_statement = True
"BEGIN; ... ; COMMIT;" --> split_statement = False
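The suggestion above could be sketched as a small hypothetical helper (not part of either hook). One caveat: Python strings also define `__iter__`, so `hasattr(sql, '__iter__')` alone would be True for `"SELECT 1"` too; an explicit `str` check is needed to capture the intent.

```python
def default_split_statements(sql) -> bool:
    """Split only when sql is a non-string iterable (e.g. a list of statements)."""
    # Note: plain strings are iterable too, so hasattr(sql, "__iter__") alone
    # would also be True for "SELECT 1"; the isinstance check captures the intent.
    return hasattr(sql, "__iter__") and not isinstance(sql, str)


# A list of statements opts into splitting; a single string does not.
assert default_split_statements(["DESCRIBE t;", "SELECT 1;"]) is True
assert default_split_statements("BEGIN; UPDATE t SET x = 1; COMMIT;") is False
```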

I'm not a native English speaker, so tell me if I'm not easy to understand, and I'll try to explain my point of view.

Good job anyway, you pointed out something interesting.

Thanks to this issue, I learnt some new things like @overload, which I had already encountered in Java but never explored. Java wasn't in my career path, but seeing it in Python made me look for explanations. I'm still learning, and this was a great opportunity to expand my beginner knowledge.

See you !

@guan404ming
Contributor

Hi, @willshen99 it seems like this PR has been inactive for a while. Could I help with this?

@bugraoz93
Contributor

Hi, @willshen99 it seems like this PR has been inactive for a while. Could I help with this?

Feel free to create a new one :)


github-actions bot commented Jun 8, 2025

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 5 days if no further activity occurs. Thank you for your contributions.

@github-actions github-actions bot added the stale label (per the .github/workflows/stale.yml policy file)
@eladkal
Contributor

eladkal commented Jun 13, 2025

@willshen99 if you can rebase and resolve conflicts, I think we can merge this PR after the tests pass

@github-actions github-actions bot removed the stale label (per the .github/workflows/stale.yml policy file)
@willshen99
Author

@eladkal Thanks for the review! I have resolved conflicts.

@eladkal
Contributor

eladkal commented Jun 26, 2025

Tests are failing

@willshen99
Author

Fixed the missing import. Sorry about that.

@eladkal
Contributor

eladkal commented Jun 26, 2025

I think we should have a test in Trino and Presto to make sure that you can change split_statements to False when using SQLExecuteQueryOperator. WDYT?
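A test along these lines could be sketched as follows. This is a hypothetical illustration: "TrinoHook" here is a local stand-in, not the real provider class, and it only shows the shape of asserting that an explicit split_statements=False (as SQLExecuteQueryOperator would forward it) reaches run() unchanged.

```python
from unittest import mock

# Hypothetical stand-in: after this PR, run() defaults split_statements to True,
# but callers must still be able to override it to False.


class TrinoHook:
    def run(self, sql, autocommit=False, parameters=None, handler=None, split_statements=True):
        return split_statements


def test_split_statements_can_be_disabled():
    # Patch run() so we can inspect exactly what the caller passed.
    with mock.patch.object(TrinoHook, "run", return_value=None) as mocked_run:
        TrinoHook().run("SELECT 1", split_statements=False)
        mocked_run.assert_called_once_with("SELECT 1", split_statements=False)
```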

@willshen99
Author

@eladkal Thanks for the suggestion. This makes total sense! I'm new to the Airflow project, and I need some help adding test cases. I don't think the test case I added works as expected.

@eladkal eladkal requested a review from bugraoz93 June 30, 2025 19:35
@eladkal eladkal requested a review from jason810496 June 30, 2025 19:35
@eladkal
Contributor

eladkal commented Jun 30, 2025

You need to enable pre-commit so it will automatically fix static check problems like wrong import order. Check the contribution guide; it explains how to enable it.

@eladkal
Contributor

eladkal commented Jun 30, 2025

Some of the tests are also failing:

FAILED providers/trino/tests/unit/trino/hooks/test_trino.py::TestTrinoHook::test_run - AssertionError: expected call not found.
Expected: run('SELECT 1', False, ('hello', 'world'), <class 'list'>, split_statements=True)
Actual: run('SELECT 1', False, ('hello', 'world'), <class 'list'>)

pytest introspection follows:

Kwargs:
assert equals failed
  {}                           {'split_statements': True}
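The mismatch above (expected split_statements=True in kwargs, actual call without it) can be reproduced and fixed in miniature. This is a hypothetical sketch: "BaseHook" and "TrinoHook" are local stand-ins, not the real classes; the point is that the hook must forward split_statements to the parent run() in exactly the form the test asserts.

```python
from unittest import mock

# Stand-ins illustrating the fix: forward split_statements explicitly as a
# keyword so the mocked parent run() sees it in call kwargs.


class BaseHook:
    def run(self, sql, autocommit=False, parameters=None, handler=None, split_statements=True):
        return sql


class TrinoHook(BaseHook):
    def run(self, sql, autocommit=False, parameters=None, handler=None, split_statements=True):
        # Forwarding split_statements=... by keyword makes the expected call
        # run('SELECT 1', False, ..., split_statements=True) match the actual one.
        return super().run(sql, autocommit, parameters, handler, split_statements=split_statements)


def test_run_forwards_split_statements():
    with mock.patch.object(BaseHook, "run") as base_run:
        TrinoHook().run("SELECT 1", False, ("hello", "world"), list)
        base_run.assert_called_once_with(
            "SELECT 1", False, ("hello", "world"), list, split_statements=True
        )
```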

Successfully merging this pull request may close these issues:

Trino split_statements is default to False