Skip to content

Commit efeab90

Browse files
feat: tensor_type for all DocVec serializations (#1679)
Signed-off-by: Johannes Messner <[email protected]>
1 parent 00e980d commit efeab90

File tree

12 files changed

+751
-363
lines changed

12 files changed

+751
-363
lines changed

docarray/array/doc_list/doc_list.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_inspect import is_union_type
2020

2121
from docarray.array.any_array import AnyDocArray
22-
from docarray.array.doc_list.io import IOMixinArray
22+
from docarray.array.doc_list.io import IOMixinDocList
2323
from docarray.array.doc_list.pushpull import PushPullMixin
2424
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
2525
from docarray.base_doc import AnyDoc, BaseDoc
@@ -42,7 +42,7 @@
4242
class DocList(
4343
ListAdvancedIndexing[T_doc],
4444
PushPullMixin,
45-
IOMixinArray,
45+
IOMixinDocList,
4646
AnyDocArray[T_doc],
4747
):
4848
"""

docarray/array/doc_list/io.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Type,
2424
TypeVar,
2525
Union,
26+
cast,
2627
)
2728

2829
import orjson
@@ -40,9 +41,12 @@
4041
if TYPE_CHECKING:
4142
import pandas as pd
4243

44+
from docarray.array.doc_vec.doc_vec import DocVec
45+
from docarray.array.doc_vec.io import IOMixinDocVec
4346
from docarray.proto import DocListProto
47+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
4448

45-
T = TypeVar('T', bound='IOMixinArray')
49+
T = TypeVar('T', bound='IOMixinDocList')
4650
T_doc = TypeVar('T_doc', bound=BaseDoc)
4751

4852
ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'}
@@ -96,7 +100,7 @@ def __getitem__(self, item: slice):
96100
return self.content[item]
97101

98102

99-
class IOMixinArray(Iterable[T_doc]):
103+
class IOMixinDocList(Iterable[T_doc]):
100104
doc_type: Type[T_doc]
101105

102106
@abstractmethod
@@ -515,8 +519,6 @@ class Person(BaseDoc):
515519
doc_dict = _access_path_dict_to_nested_dict(access_path2val)
516520
docs.append(doc_type.parse_obj(doc_dict))
517521

518-
if not isinstance(docs, cls):
519-
return cls(docs)
520522
return docs
521523

522524
def to_dataframe(self) -> 'pd.DataFrame':
@@ -577,11 +579,13 @@ def _load_binary_all(
577579
protocol: Optional[str],
578580
compress: Optional[str],
579581
show_progress: bool,
582+
tensor_type: Optional[Type['AbstractTensor']] = None,
580583
):
581584
"""Read a `DocList` object from a binary file
582585
:param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'json-array', 'pickle' or 'protobuf'
583586
:param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
584587
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
588+
:param tensor_type: only relevant for DocVec; tensor_type of the DocVec
585589
:return: a `DocList`
586590
"""
587591
with file_ctx as fp:
@@ -603,12 +607,20 @@ def _load_binary_all(
603607
proto = cls._get_proto_class()()
604608
proto.ParseFromString(d)
605609

606-
return cls.from_protobuf(proto)
610+
if tensor_type is not None:
611+
cls_ = cast('IOMixinDocVec', cls)
612+
return cls_.from_protobuf(proto, tensor_type=tensor_type)
613+
else:
614+
return cls.from_protobuf(proto)
607615
elif protocol is not None and protocol == 'pickle-array':
608616
return pickle.loads(d)
609617

610618
elif protocol is not None and protocol == 'json-array':
611-
return cls.from_json(d)
619+
if tensor_type is not None:
620+
cls_ = cast('IOMixinDocVec', cls)
621+
return cls_.from_json(d, tensor_type=tensor_type)
622+
else:
623+
return cls.from_json(d)
612624

613625
# Binary format for streaming case
614626
else:
@@ -658,6 +670,10 @@ def _load_binary_all(
658670
pbar.update(
659671
t, advance=1, total_size=str(filesize.decimal(_total_size))
660672
)
673+
if tensor_type is not None:
674+
cls__ = cast(Type['DocVec'], cls)
675+
# mypy doesn't realize that cls__ is callable
676+
return cls__(docs, tensor_type=tensor_type) # type: ignore
661677
return cls(docs)
662678

663679
@classmethod
@@ -724,6 +740,27 @@ def _load_binary_stream(
724740
t, advance=1, total_size=str(filesize.decimal(_total_size))
725741
)
726742

743+
@staticmethod
744+
def _get_file_context(
745+
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
746+
protocol: str,
747+
compress: Optional[str] = None,
748+
) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
749+
load_protocol: Optional[str] = protocol
750+
load_compress: Optional[str] = compress
751+
file_ctx: Union[nullcontext, io.BufferedReader]
752+
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
753+
file_ctx = nullcontext(file)
754+
# by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
755+
elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
756+
load_protocol, load_compress = _protocol_and_compress_from_file_path(
757+
file, protocol, compress
758+
)
759+
file_ctx = open(file, 'rb')
760+
else:
761+
raise FileNotFoundError(f'cannot find file {file}')
762+
return file_ctx, load_protocol, load_compress
763+
727764
@classmethod
728765
def load_binary(
729766
cls: Type[T],
@@ -753,19 +790,9 @@ def load_binary(
753790
:return: a `DocList` object
754791
755792
"""
756-
load_protocol: Optional[str] = protocol
757-
load_compress: Optional[str] = compress
758-
file_ctx: Union[nullcontext, io.BufferedReader]
759-
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
760-
file_ctx = nullcontext(file)
761-
# by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
762-
elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
763-
load_protocol, load_compress = _protocol_and_compress_from_file_path(
764-
file, protocol, compress
765-
)
766-
file_ctx = open(file, 'rb')
767-
else:
768-
raise FileNotFoundError(f'cannot find file {file}')
793+
file_ctx, load_protocol, load_compress = cls._get_file_context(
794+
file, protocol, compress
795+
)
769796
if streaming:
770797
if load_protocol not in SINGLE_PROTOCOLS:
771798
raise ValueError(

0 commit comments

Comments
 (0)