Skip to content

fix: make DocList compatible with BaseDocWithoutId #1805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docarray/array/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from docarray.base_doc import BaseDoc
from docarray.base_doc.doc import BaseDocWithoutId
from docarray.display.document_array_summary import DocArraySummary
from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
Expand All @@ -30,7 +30,7 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='AnyDocArray')
T_doc = TypeVar('T_doc', bound=BaseDoc)
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]

UNUSABLE_ERROR_MSG = (
Expand All @@ -42,18 +42,18 @@


class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
doc_type: Type[BaseDoc]
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDoc], Type]] = {}
doc_type: Type[BaseDocWithoutId]
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}

def __repr__(self):
return f'<{self.__class__.__name__} (length={len(self)})>'

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
if not isinstance(item, type):
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
# this does nothing other than checking that item is a valid TypeVar or str
if not safe_issubclass(item, BaseDoc):
if not safe_issubclass(item, BaseDocWithoutId):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
)
Expand All @@ -66,7 +66,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
global _DocArrayTyped

class _DocArrayTyped(cls): # type: ignore
doc_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)

for field in _DocArrayTyped.doc_type._docarray_fields().keys():

Expand Down
17 changes: 9 additions & 8 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

from pydantic import parse_obj_as
from typing_extensions import SupportsIndex
from typing_inspect import is_union_type, is_typevar
from typing_inspect import is_typevar, is_union_type

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.io import IOMixinDocList
from docarray.array.doc_list.pushpull import PushPullMixin
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.base_doc import AnyDoc
from docarray.base_doc.doc import BaseDocWithoutId
from docarray.typing import NdArray
from docarray.utils._internal.pydantic import is_pydantic_v2

Expand All @@ -40,7 +41,7 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='DocList')
T_doc = TypeVar('T_doc', bound=BaseDoc)
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)


class DocList(
Expand Down Expand Up @@ -120,7 +121,7 @@ class Image(BaseDoc):

"""

doc_type: Type[BaseDoc] = AnyDoc
doc_type: Type[BaseDocWithoutId] = AnyDoc

def __init__(
self,
Expand Down Expand Up @@ -229,7 +230,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
not is_union_type(field_type)
and is_field_required
and isinstance(field_type, type)
and safe_issubclass(field_type, BaseDoc)
and safe_issubclass(field_type, BaseDocWithoutId)
):
# calling __class_getitem__ ourselves is a hack; otherwise mypy complains
# most likely a bug in mypy though
Expand Down Expand Up @@ -273,7 +274,7 @@ def to_doc_vec(
@classmethod
def _docarray_validate(
cls: Type[T],
value: Union[T, Iterable[BaseDoc]],
value: Union[T, Iterable[BaseDocWithoutId]],
):
from docarray.array.doc_vec.doc_vec import DocVec

Expand Down Expand Up @@ -333,9 +334,9 @@ def __getitem__(self, item):
return super().__getitem__(item)

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):

if isinstance(item, type) and safe_issubclass(item, BaseDoc):
if isinstance(item, type) and safe_issubclass(item, BaseDocWithoutId):
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
if (
isinstance(item, object)
Expand Down
11 changes: 11 additions & 0 deletions tests/units/document/test_doc_wo_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from docarray import DocList
from docarray.base_doc.doc import BaseDocWithoutId


def test_doc_list():
    """Check that DocList can be parametrized with a BaseDocWithoutId subclass.

    Subscripting ``DocList`` with the document class must yield a class
    object rather than raising.
    """

    class A(BaseDocWithoutId):
        text: str

    # Subscripting is the operation under test; the result must itself be a type.
    parametrized = DocList[A]
    assert isinstance(parametrized, type)