    Type,
    TypeVar,
    Union,
+   cast,
)

import orjson

if TYPE_CHECKING:
    import pandas as pd

+   from docarray.array.doc_vec.doc_vec import DocVec
+   from docarray.array.doc_vec.io import IOMixinDocVec
    from docarray.proto import DocListProto
+   from docarray.typing.tensor.abstract_tensor import AbstractTensor

- T = TypeVar('T', bound='IOMixinArray')
+ T = TypeVar('T', bound='IOMixinDocList')
T_doc = TypeVar('T_doc', bound=BaseDoc)


ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'}
@@ -96,7 +100,7 @@ def __getitem__(self, item: slice):
        return self.content[item]


- class IOMixinArray(Iterable[T_doc]):
+ class IOMixinDocList(Iterable[T_doc]):
    doc_type: Type[T_doc]

    @abstractmethod
@@ -515,8 +519,6 @@ class Person(BaseDoc):
            doc_dict = _access_path_dict_to_nested_dict(access_path2val)
            docs.append(doc_type.parse_obj(doc_dict))

-       if not isinstance(docs, cls):
-           return cls(docs)
        return docs

    def to_dataframe(self) -> 'pd.DataFrame':
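For reference, a minimal sketch of the `from_dataframe`/`to_dataframe` round trip these methods implement, using docarray's public `DocList` API; the `Person` fields and the sample values are illustrative only, not taken from the diff:

import pandas as pd
from docarray import BaseDoc, DocList


class Person(BaseDoc):
    name: str
    age: int


# one DataFrame row per document, one column per document field
df = pd.DataFrame({'name': ['Alice', 'Bob'], 'age': [30, 25]})
people = DocList[Person].from_dataframe(df)
print(people[0].name)              # 'Alice'
df_again = people.to_dataframe()   # back to a DataFrame (plus docarray's auto-generated 'id' column)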
@@ -577,11 +579,13 @@ def _load_binary_all(
        protocol: Optional[str],
        compress: Optional[str],
        show_progress: bool,
+       tensor_type: Optional[Type['AbstractTensor']] = None,
    ):
        """Read a `DocList` object from a binary file
        :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
        :param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
        :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+       :param tensor_type: only relevant for DocVec; tensor_type of the DocVec
        :return: a `DocList`
        """
        with file_ctx as fp:
@@ -603,12 +607,20 @@ def _load_binary_all(
                proto = cls._get_proto_class()()
                proto.ParseFromString(d)

-               return cls.from_protobuf(proto)
+               if tensor_type is not None:
+                   cls_ = cast('IOMixinDocVec', cls)
+                   return cls_.from_protobuf(proto, tensor_type=tensor_type)
+               else:
+                   return cls.from_protobuf(proto)
            elif protocol is not None and protocol == 'pickle-array':
                return pickle.loads(d)

            elif protocol is not None and protocol == 'json-array':
-               return cls.from_json(d)
+               if tensor_type is not None:
+                   cls_ = cast('IOMixinDocVec', cls)
+                   return cls_.from_json(d, tensor_type=tensor_type)
+               else:
+                   return cls.from_json(d)

            # Binary format for streaming case
            else:
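Both new branches rely on `typing.cast`, which is a no-op at runtime: it returns `cls` unchanged and only tells the type checker to treat it as the `IOMixinDocVec` variant whose `from_protobuf`/`from_json` accept a `tensor_type` keyword (the diff casts to the string literal 'IOMixinDocVec' because that class is only imported under TYPE_CHECKING). A self-contained sketch of the same pattern with placeholder classes; none of these names are docarray's:

from typing import Optional, Type, cast


class PlainIO:
    @classmethod
    def from_json(cls, data: str) -> 'PlainIO':
        return cls()


class ColumnIO(PlainIO):
    tensor_type: Optional[type] = None

    @classmethod
    def from_json(cls, data: str, tensor_type: Optional[type] = None) -> 'ColumnIO':
        obj = cls()
        obj.tensor_type = tensor_type
        return obj


def load(cls: Type[PlainIO], data: str, tensor_type: Optional[type] = None) -> PlainIO:
    if tensor_type is not None:
        # cast() does nothing at runtime; it only changes the declared type of cls
        # so that the extra tensor_type keyword type-checks.
        cls_ = cast(Type[ColumnIO], cls)
        return cls_.from_json(data, tensor_type=tensor_type)
    return cls.from_json(data)


print(type(load(ColumnIO, '{}', tensor_type=list)).__name__)  # ColumnIO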
@@ -658,6 +670,10 @@ def _load_binary_all(
                        pbar.update(
                            t, advance=1, total_size=str(filesize.decimal(_total_size))
                        )
+       if tensor_type is not None:
+           cls__ = cast(Type['DocVec'], cls)
+           # mypy doesn't realize that cls__ is callable
+           return cls__(docs, tensor_type=tensor_type)  # type: ignore
        return cls(docs)

    @classmethod
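Here the cast targets `Type['DocVec']` because the final constructor call needs the `tensor_type` keyword: a `DocVec` stores fields column-wise and has to know which tensor backend to stack them into, which is why the parameter is threaded all the way through. A short sketch using docarray's public `DocVec` API; the `MyDoc` schema is made up for illustration:

import numpy as np
from docarray import BaseDoc, DocVec
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    embedding: NdArray


docs = [MyDoc(embedding=np.zeros(3)) for _ in range(4)]

# tensor_type decides the backend used for the stacked columns
dv = DocVec[MyDoc](docs, tensor_type=NdArray)
print(dv.embedding.shape)  # (4, 3): one stacked column instead of four separate tensors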
@@ -724,6 +740,27 @@ def _load_binary_stream(
                    t, advance=1, total_size=str(filesize.decimal(_total_size))
                )

+   @staticmethod
+   def _get_file_context(
+       file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
+       protocol: str,
+       compress: Optional[str] = None,
+   ) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
+       load_protocol: Optional[str] = protocol
+       load_compress: Optional[str] = compress
+       file_ctx: Union[nullcontext, io.BufferedReader]
+       if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
+           file_ctx = nullcontext(file)
+       # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
+       elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
+           load_protocol, load_compress = _protocol_and_compress_from_file_path(
+               file, protocol, compress
+           )
+           file_ctx = open(file, 'rb')
+       else:
+           raise FileNotFoundError(f'cannot find file {file}')
+       return file_ctx, load_protocol, load_compress
+
    @classmethod
    def load_binary(
        cls: Type[T],
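The new `_get_file_context` staticmethod lifts the input-resolution logic verbatim out of `load_binary` (removed in the next hunk) so other loaders can reuse it. Assuming the helper lands as shown, a hedged sketch of its behaviour: a path re-derives protocol/compress from the file name, while in-memory bytes are wrapped in a `nullcontext` and the passed arguments are kept as-is:

import pathlib
from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


# write a file whose name encodes protocol and compression, docarray-style
DocList[MyDoc]([MyDoc(text='hi')]).save_binary('docs.protobuf.gzip')

# path input: protocol/compress are re-derived from the extension
ctx, protocol, compress = DocList[MyDoc]._get_file_context(
    'docs.protobuf.gzip', protocol='protobuf', compress=None
)
print(protocol, compress)  # expected: protobuf gzip
ctx.close()

# bytes input: wrapped in nullcontext, arguments pass through unchanged
raw = pathlib.Path('docs.protobuf.gzip').read_bytes()
ctx, protocol, compress = DocList[MyDoc]._get_file_context(
    raw, protocol='protobuf', compress='gzip'
)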
@@ -753,19 +790,9 @@ def load_binary(
        :return: a `DocList` object

        """
-       load_protocol: Optional[str] = protocol
-       load_compress: Optional[str] = compress
-       file_ctx: Union[nullcontext, io.BufferedReader]
-       if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
-           file_ctx = nullcontext(file)
-       # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
-       elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
-           load_protocol, load_compress = _protocol_and_compress_from_file_path(
-               file, protocol, compress
-           )
-           file_ctx = open(file, 'rb')
-       else:
-           raise FileNotFoundError(f'cannot find file {file}')
+       file_ctx, load_protocol, load_compress = cls._get_file_context(
+           file, protocol, compress
+       )
        if streaming:
            if load_protocol not in SINGLE_PROTOCOLS:
                raise ValueError(
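`load_binary` keeps the same behaviour after this refactor; only the file/protocol resolution moved into the helper above. For reference, a round trip with the public API, using one protocol/compress combination from the docstring (schema and file name are arbitrary):

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str


dl = DocList[MyDoc]([MyDoc(text=str(i)) for i in range(3)])
dl.save_binary('docs.bin', protocol='protobuf-array', compress='gzip')

loaded = DocList[MyDoc].load_binary('docs.bin', protocol='protobuf-array', compress='gzip')
print([d.text for d in loaded])  # ['0', '1', '2']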