Skip to content

feat: add JAX as Computation Backend #1646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jul 18, 2023

Conversation

agaraman0
Copy link
Contributor

@agaraman0 agaraman0 commented Jun 14, 2023

Integrating JAX as a computation backend in Docarray similar to Pytorch, Tensorflow, and Numpy

  • JaxArray typing implementation
  • JaxArray typing unit tests
  • JaxCompBackend methods implementation
  • JaxCompBackend basic tests
  • DocVec working
  • Integration tests
  • function annotations and commenting
  • docs updations
from docarray import BaseDoc
from docarray.typing import JaxArray
import jax.numpy as jnp


class MyDoc(BaseDoc):
    arr: JaxArray
    image_arr: JaxArray[3, 224, 224]
    square_crop: JaxArray[3, 'x', 'x']
    random_image: JaxArray[3, ...]  # first dimension is fixed, can have arbitrary shape


# create a document with tensors
doc = MyDoc(
    arr=jnp.zeros((128,)),
    image_arr=jnp.zeros((3, 224, 224)),
    square_crop=jnp.zeros((3, 64, 64)),
    random_image=jnp.zeros((3, 128, 256)),
)
assert doc.image_arr.shape == (3, 224, 224)

# automatic shape conversion
doc = MyDoc(
    arr=np.zeros((128,)),
    image_arr=np.zeros((224, 224, 3)),  # will reshape to (3, 224, 224)
    square_crop=np.zeros((3, 128, 128)),
    random_image=np.zeros((3, 64, 128)),
)
assert doc.image_arr.shape == (3, 224, 224)

@agaraman0 agaraman0 force-pushed the feat-comp-jax-backend branch 3 times, most recently from 1aef069 to fdcbe92 Compare June 20, 2023 10:15
@agaraman0 agaraman0 marked this pull request as ready for review June 22, 2023 10:04
@samsja samsja self-requested a review June 22, 2023 10:06
Copy link
Member

@samsja samsja left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks mostly good ! Once we have :

  • DocVec working
  • Integration test

we will be good to go

@agaraman0 agaraman0 requested review from JoanFM and samsja June 23, 2023 12:27
@agaraman0 agaraman0 force-pushed the feat-comp-jax-backend branch from a679638 to 83a2b9b Compare July 4, 2023 04:39
Copy link
Member

@samsja samsja left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks really good. I added some tiny comments.

We are just missing the interagration test before we merge !

@samsja samsja changed the title JAX as Computation Backend similar to Pytorch, Tensorflow and Numpy Feat: add JAX as Computation Backend Jul 5, 2023
@samsja
Copy link
Member

samsja commented Jul 5, 2023

btw looking at the CI it seems that you still have place where you load jax without using the helper import and therefore it throw an error

@JoanFM
Copy link
Member

JoanFM commented Jul 5, 2023

Hey @agaraman0 ,

would it be possible for you to add a code snippet as an example of the usage of this feature in the PR description. Like this we can add it in future release notes to highlight this awesome feature.

@agaraman0
Copy link
Contributor Author

Hey @agaraman0 ,

would it be possible for you to add a code snippet as an example of the usage of this feature in the PR description. Like this we can add it in future release notes to highlight this awesome feature.

Sure will do that, I am just looking for some existing example if I can replicate one.

@JoanFM JoanFM changed the title Feat: add JAX as Computation Backend feat: add JAX as Computation Backend Jul 6, 2023
@agaraman0 agaraman0 force-pushed the feat-comp-jax-backend branch 2 times, most recently from 92e32ec to 6b57436 Compare July 11, 2023 11:54
@JoanFM JoanFM requested review from scott-martens and removed request for scott-martens July 14, 2023 13:08
@JoanFM
Copy link
Member

JoanFM commented Jul 14, 2023

@agaraman0 can you add an equivalent test to the one added for Pytorch and TF in this PR #1696

Signed-off-by: agaraman0 <[email protected]>
Copy link
Contributor

@scott-martens scott-martens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I found all docstrings. A few changes recommended.

@agaraman0
Copy link
Contributor Author

Hey @samsja @JoanFM PR is ready to be merged, addressing all concerns and review comments.

Copy link
Member

@samsja samsja left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great PR looking forward to merge

@JoanFM JoanFM merged commit b306c80 into docarray:main Jul 18, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

4 participants