feat: UserActionsをClientと分けた

pull/110/head
yupix 10 months ago
parent 966601ed6a
commit 171f60bea4
No known key found for this signature in database
GPG Key ID: 2FF705F5C56D9C06

@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, overload
from typing_extensions import override
from mipac.abstract.action import AbstractAction
from mipac.errors.base import NotExistRequiredData, ParameterError
from mipac.http import HTTPClient, Route
@ -14,7 +16,7 @@ from mipac.types.user import IMeDetailedSchema, IUser, is_partial_user
from mipac.utils.cache import cache
from mipac.utils.format import remove_dict_empty
from mipac.utils.pagination import Pagination
from mipac.utils.util import check_multi_arg
from mipac.utils.util import check_multi_arg, credentials_required
if TYPE_CHECKING:
from mipac.manager.client import ClientManager
@ -22,20 +24,16 @@ if TYPE_CHECKING:
__all__ = ["UserActions"]
class UserActions:
class ClientUserActions(AbstractAction):
def __init__(
self,
session: HTTPClient,
client: ClientManager,
user: Optional[PartialUser] = None,
self, user: PartialUser | None = None, *, session: HTTPClient, client: ClientManager
):
self.__session: HTTPClient = session
self.__user: Optional[PartialUser] = user
self.__client: ClientManager = client
self._user: PartialUser | None = user
self._session: HTTPClient = session
self._client: ClientManager = client
async def get_notes(
self,
user_id: str | None = None,
with_replies: bool = False,
with_renotes: bool = True,
limit: int = 10,
@ -47,11 +45,13 @@ class UserActions:
with_files: bool = False,
file_type: list[str] | None = None,
exclude_nsfw: bool = False,
*,
user_id: str | None = None,
) -> list[Note]: # TODO: since_dataなどを用いたページネーションを今後できるようにする
if check_multi_arg(user_id, self.__user) is False:
raise ParameterError("missing required argument: user_id", user_id, self.__user)
if check_multi_arg(user_id, self._user) is False:
raise ParameterError("missing required argument: user_id", user_id, self._user)
user_id = user_id or self.__user and self.__user.id
user_id = user_id or self._user and self._user.id
data = {
"userId": user_id,
"withReplies": with_replies,
@ -67,15 +67,14 @@ class UserActions:
"excludeNsfw": exclude_nsfw,
}
raw_note: list[INote] = await self.__session.request(
raw_note: list[INote] = await self._session.request(
Route("POST", "/api/users/notes"), json=data
)
return [Note(raw_note=raw_note, client=self.__client) for raw_note in raw_note]
return [Note(raw_note=raw_note, client=self._client) for raw_note in raw_note]
async def get_all_notes(
self,
user_id: str | None = None,
with_replies: bool = False,
with_renotes: bool = True,
since_id: str | None = None,
@ -86,8 +85,10 @@ class UserActions:
with_files: bool = False,
file_type: list[str] | None = None,
exclude_nsfw: bool = False,
):
user_id = user_id or self.__user and self.__user.id
*,
user_id: str | None = None,
) -> AsyncGenerator[Note, None]:
user_id = user_id or self._user and self._user.id
data = {
"userId": user_id,
"withReplies": with_replies,
@ -102,43 +103,173 @@ class UserActions:
"fileType": file_type,
"excludeNsfw": exclude_nsfw,
}
pagination = Pagination[INote](
self.__session, Route("POST", "/api/users/notes"), json=data
)
pagination = Pagination[INote](self._session, Route("POST", "/api/users/notes"), json=data)
while pagination.is_final is False:
res_notes = await pagination.next()
for note in res_notes:
yield Note(note, client=self.__client)
yield Note(note, client=self._client)
async def get_clips(
self,
limit: int = 10,
since_id: str | None = None,
until_id: str | None = None,
*,
user_id: str | None = None,
) -> list[Clip]:
data = {"userId": user_id, "limit": limit, "sinceId": since_id, "untilId": until_id}
raw_clip: list[IClip] = await self._session.request(
Route("POST", "/api/users/clips"), json=data, auth=True
)
return [Clip(raw_clip=raw_clip, client=self._client) for raw_clip in raw_clip]
async def get_all_clips(
self,
since_id: str | None = None,
until_id: str | None = None,
*,
user_id: str | None = None,
) -> AsyncGenerator[Clip, None]:
user_id = user_id or self._user and self._user.id
if user_id is None:
raise ParameterError("user_id is required")
data = {"userId": user_id, "limit": 100, "sinceId": since_id, "untilId": until_id}
pagination = Pagination[IClip](
self._session, Route("POST", "/api/users/clips"), json=data, auth=True
)
while pagination.is_final is False:
clips: list[IClip] = await pagination.next()
for clip in clips:
yield Clip(raw_clip=clip, client=self._client)
async def get_achievements(self, *, user_id: str | None = None) -> list[Achievement]:
"""Get achievements of user."""
user_id = user_id or self._user and self._user.id
if not user_id:
raise ParameterError("user_id is required")
data = {
"userId": user_id,
}
res = await self._session.request(
Route("POST", "/api/users/achievements"),
json=data,
auth=True,
lower=True,
)
return [Achievement(i) for i in res]
class UserActions(ClientUserActions):
def __init__(
self,
session: HTTPClient,
client: ClientManager,
):
super().__init__(session=session, client=client)
@override
async def get_notes(
self,
user_id: str,
with_replies: bool = False,
with_renotes: bool = True,
limit: int = 10,
since_id: str | None = None,
until_id: str | None = None,
since_data: int | None = None,
until_data: int | None = None,
include_my_renotes: bool = True,
with_files: bool = False,
file_type: list[str] | None = None,
exclude_nsfw: bool = False,
) -> list[Note]:
return await super().get_notes(
with_replies=with_replies,
with_renotes=with_renotes,
limit=limit,
since_id=since_id,
until_id=until_id,
since_data=since_data,
until_data=until_data,
include_my_renotes=include_my_renotes,
with_files=with_files,
file_type=file_type,
exclude_nsfw=exclude_nsfw,
user_id=user_id,
)
@override
async def get_all_notes(
self,
user_id: str,
with_replies: bool = False,
with_renotes: bool = True,
since_id: str | None = None,
until_id: str | None = None,
since_data: int | None = None,
until_data: int | None = None,
include_my_renotes: bool = True,
with_files: bool = False,
file_type: list[str] | None = None,
exclude_nsfw: bool = False,
) -> AsyncGenerator[Note, None]:
async for i in super().get_all_notes(
with_replies=with_replies,
with_renotes=with_renotes,
since_id=since_id,
until_id=until_id,
since_data=since_data,
until_data=until_data,
include_my_renotes=include_my_renotes,
with_files=with_files,
file_type=file_type,
exclude_nsfw=exclude_nsfw,
user_id=user_id,
):
yield i
@override
async def get_clips(
self,
user_id: str,
limit: int = 10,
since_id: str | None = None,
until_id: str | None = None,
) -> list[Clip]:
return await super().get_clips(
user_id=user_id, limit=limit, since_id=since_id, until_id=until_id
)
@override
async def get_all_clips(
self, user_id: str, since_id: str | None = None, until_id: str | None = None
) -> AsyncGenerator[Clip, None]:
async for i in super().get_all_clips(
user_id=user_id, since_id=since_id, until_id=until_id
):
yield i
@credentials_required
async def get_me(self) -> MeDetailed: # TODO: トークンが無い場合は例外返すようにする
"""
ログインしているユーザーの情報を取得します
"""
res: IMeDetailedSchema = await self.__session.request(
res: IMeDetailedSchema = await self._session.request(
Route("POST", "/api/i"),
auth=True,
lower=True,
)
return MeDetailed(res, client=self.__client)
def get_profile_link(
self,
external: bool = True,
protocol: Literal["http", "https"] = "https",
): # TODO: これモデルに移すべきな気がする
if not self.__user:
return None
host = (
f"{protocol}://{self.__user.host}"
if external and self.__user.host
else self.__session._url
)
path = (
f"/@{self.__user.username}" if external else f"/{self.__user.api.action.get_mention()}"
)
return host + path
return MeDetailed(res, client=self._client)
@cache(group="get_user")
async def get(
@ -169,10 +300,10 @@ class UserActions:
field = remove_dict_empty(
{"userId": user_id, "username": username, "host": host, "userIds": user_ids}
)
data: IUser = await self.__session.request(
data: IUser = await self._session.request(
Route("POST", "/api/users/show"), json=field, auth=True, lower=True
)
return packed_user(data, client=self.__client)
return packed_user(data, client=self._client)
async def fetch(
self,
@ -217,12 +348,58 @@ class UserActions:
メンション
"""
user = user or self.__user
user = user or self._user
if user is None:
raise NotExistRequiredData("Required parameters: user")
return f"@{user.username}@{user.host}" if user.instance else f"@{user.username}"
async def search_by_username_and_host(
self,
username: str,
host: str,
limit: int = 100,
detail: bool = True,
) -> list[UserDetailedNotMe | MeDetailed | PartialUser]: # TODO: 続き
"""
Search users by username and host.
Parameters
----------
username : str
Username of user.
host : str
Host of user.
limit : int, default=100
The maximum number of users to return.
detail : bool, default=True
Weather to get detailed user information.
Returns
-------
list[UserDetailedNotMe | MeDetailed | PartialUser]
A list of users.
"""
if limit > 100:
raise ParameterError("limit は100以下である必要があります")
body = remove_dict_empty(
{"username": username, "host": host, "limit": limit, "detail": detail}
)
res = await self._session.request(
Route("POST", "/api/users/search-by-username-and-host"),
lower=True,
auth=True,
json=body,
)
return [
packed_user(user, client=self._client)
if detail
else PartialUser(user, client=self._client)
for user in res
]
@overload
async def search(
self,
@ -290,7 +467,7 @@ class UserActions:
)
pagination = Pagination[IUser](
self.__session,
self._session,
Route("POST", "/api/users/search"),
json=body,
pagination_type="count",
@ -300,105 +477,13 @@ class UserActions:
users: list[IUser] = await pagination.next()
for user in users:
yield (
packed_user(user, client=self.__client)
packed_user(user, client=self._client)
if is_partial_user(user) is False
else PartialUser(user, client=self.__client)
else PartialUser(user, client=self._client)
)
if get_all is False or pagination.is_final:
break
async def search_by_username_and_host(
self,
username: str,
host: str,
limit: int = 100,
detail: bool = True,
) -> list[UserDetailedNotMe | MeDetailed | PartialUser]: # TODO: 続き
"""
Search users by username and host.
Parameters
----------
username : str
Username of user.
host : str
Host of user.
limit : int, default=100
The maximum number of users to return.
detail : bool, default=True
Weather to get detailed user information.
Returns
-------
list[UserDetailedNotMe | MeDetailed | PartialUser]
A list of users.
"""
if limit > 100:
raise ParameterError("limit は100以下である必要があります")
body = remove_dict_empty(
{"username": username, "host": host, "limit": limit, "detail": detail}
)
res = await self.__session.request(
Route("POST", "/api/users/search-by-username-and-host"),
lower=True,
auth=True,
json=body,
)
return [
packed_user(user, client=self.__client)
if detail
else PartialUser(user, client=self.__client)
for user in res
]
async def get_achievements(self, user_id: str | None = None) -> list[Achievement]:
"""Get achievements of user."""
user_id = user_id or self.__user and self.__user.id
if not user_id:
raise ParameterError("user_id is required")
data = {
"userId": user_id,
}
res = await self.__session.request(
Route("POST", "/api/users/achievements"),
json=data,
auth=True,
lower=True,
)
return [Achievement(i) for i in res]
async def get_clips(
self,
user_id: str | None = None,
limit: int = 10,
since_id: str | None = None,
until_id: str | None = None,
get_all: bool = False,
):
user_id = user_id or self.__user and self.__user.id
if not user_id:
raise ParameterError("user_id is required")
if limit > 100:
raise ParameterError("limit must be less than 100")
if get_all:
limit = 100
body = {"userId": user_id, "limit": limit, "sinceId": since_id, "untilId": until_id}
pagination = Pagination[IClip](
self.__session, Route("POST", "/api/users/clips"), json=body, auth=True
)
while True:
clips: list[IClip] = await pagination.next()
for clip in clips:
yield Clip(clip, client=self.__client)
if get_all is False or pagination.is_final:
break
@override
async def get_achievements(self, user_id: str) -> list[Achievement]:
return await super().get_achievements(user_id=user_id)

Loading…
Cancel
Save