diff --git a/mipac/actions/user.py b/mipac/actions/user.py index 21535a6..84a7e90 100644 --- a/mipac/actions/user.py +++ b/mipac/actions/user.py @@ -13,6 +13,8 @@ from mipac.models.user import ( MeDetailed, MeDetailedModerator, UserDetailed, + UserDetailedModels, + UserModels, create_user_model, ) from mipac.types.clip import IClip @@ -27,7 +29,7 @@ from mipac.types.user import ( ) from mipac.utils.cache import cache from mipac.utils.format import remove_dict_empty -from mipac.utils.pagination import Pagination, pagination_iterator +from mipac.utils.pagination import Pagination from mipac.utils.util import check_multi_arg if TYPE_CHECKING: @@ -239,7 +241,7 @@ class UserActions: detail: Literal[True] = True, *, get_all: bool = False, - ) -> AsyncGenerator[UserDetailed, None]: + ) -> AsyncGenerator[UserDetailedModels, None]: ... async def search( @@ -251,7 +253,7 @@ class UserActions: detail: Literal[True, False] = True, *, get_all: bool = False, - ) -> AsyncGenerator[UserDetailed | PartialUser, None]: + ) -> AsyncGenerator[UserModels, None]: """ Search users by keyword. @@ -286,29 +288,19 @@ class UserActions: {"query": query, "limit": limit, "offset": offset, "origin": origin, "detail": detail} ) - if detail: - pagination = Pagination[IUserDetailed]( - self.__session, - Route("POST", "/api/users/search"), - json=body, - pagination_type="count", - ) - iterator = pagination_iterator( - pagination, get_all, model=UserDetailed, client=self.__client - ) - else: - pagination = Pagination[IPartialUser]( - self.__session, - Route("POST", "/api/users/search"), - json=body, - pagination_type="count", - ) - - iterator = pagination_iterator( - pagination, get_all=get_all, model=PartialUser, client=self.__client - ) - async for user in iterator: - yield user + pagination = Pagination[IUser]( + self.__session, + Route("POST", "/api/users/search"), + json=body, + pagination_type="count", + ) + + while True: + users: list[IUser] = await pagination.next() + for user in users: + yield create_user_model(user, client=self.__client) + if get_all is False or pagination.is_final: + break async def search_by_username_and_host( self,