From b5befe73934d750f4bedfde9693d85558386a4eb Mon Sep 17 00:00:00 2001 From: Mint <> Date: Mon, 6 Mar 2023 02:04:38 +0300 Subject: [PATCH] Rudimentary word-based reply generation --- config.defaults.json | 1 + generators/markov.py | 21 ++++++++++++++++++--- pleroma.py | 2 +- reply.py | 19 +++++++++++++++---- third_party/utils.py | 4 ++-- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/config.defaults.json b/config.defaults.json index 8216794..c069dba 100644 --- a/config.defaults.json +++ b/config.defaults.json @@ -14,6 +14,7 @@ "overlap_ratio_enabled": false, "overlap_ratio": 0.7, "generation_mode": "markov", + "keyword_from_reply": true, "access_token": "", "db_path": "posts.db" } diff --git a/generators/markov.py b/generators/markov.py index 7e1d7a4..d47ac21 100644 --- a/generators/markov.py +++ b/generators/markov.py @@ -2,9 +2,9 @@ import sqlite3 import markovify -from random import randint +from random import randint, choice -def make_sentence(cfg): +def make_sentence(cfg, keywords): class nlt_fixed(markovify.NewlineText): # modified version of NewlineText that never rejects sentences def test_sentence_input(self, sentence): return True # all sentences are valid <3 @@ -49,10 +49,25 @@ def make_sentence(cfg): if cfg['limit_length']: sentence_len = randint(cfg['length_lower_limit'], cfg['length_upper_limit']) + def make_short_sentence_with_keyword(max_chars, min_chars=0, keywords=None, **kwargs): + tries = kwargs.get("tries") + for _ in range(tries): + if keywords: + try: + keyword = choice(model.word_split(keywords)) + sentence = model.make_sentence_with_start(keyword, strict=False, **kwargs) + except: + sentence = model.make_sentence(**kwargs) + else: + sentence = model.make_sentence(**kwargs) + if sentence and min_chars <= len(sentence) <= max_chars: + return sentence + sentence = None tries = 0 for tries in range(10): - if (sentence := model.make_short_sentence( + if (sentence := make_short_sentence_with_keyword( + keywords=keywords if cfg['keywords_from_reply'] else None, max_chars=500, tries=10000, max_overlap_ratio=cfg['overlap_ratio'] if cfg['overlap_ratio_enabled'] else 0.7, diff --git a/pleroma.py b/pleroma.py index a553b15..1430486 100644 --- a/pleroma.py +++ b/pleroma.py @@ -66,7 +66,7 @@ class Pleroma: id = self._unpack_id(id) return await self.request('GET', f'/api/v1/statuses/{id}/context') - async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None): + async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None, keywords=None): if visibility not in {None, 'private', 'public', 'unlisted', 'direct'}: raise ValueError('invalid visibility', visibility) diff --git a/reply.py b/reply.py index 13b72b8..0d815c4 100755 --- a/reply.py +++ b/reply.py @@ -77,9 +77,14 @@ class ReplyBot: if command in ('pin', 'unpin'): await (self.pleroma.pin if command == 'pin' else self.pleroma.unpin)(target_post_id) elif command == 'reply': - toot = await utils.make_post(self.cfg) - toot = self.cleanup_toot(toot, self.cfg) status = await self.pleroma.get_status(argument) + if status['content'] != "": + keywords = cleanup_toot(utils.extract_post_content(status['content']), self.cfg) + else: + keywords = None + + toot = await utils.make_post(self.cfg, keywords) + toot = self.cleanup_toot(toot, self.cfg) await self.pleroma.reply(status, toot, cw=self.cfg['cw']) except pleroma.BadRequest as exc: async with anyio.create_task_group() as tg: @@ -89,9 +94,15 @@ class ReplyBot: await self.pleroma.react(post_id, '✅') async def reply(self, notification): - toot = await utils.make_post(self.cfg) + status = notification['status'] + if status['content'] != "": + keywords = self.cleanup_toot(utils.extract_post_content(status['content']), self.cfg) + else: + keywords = None + + toot = await utils.make_post(self.cfg, keywords) toot = self.cleanup_toot(toot, self.cfg) - await self.pleroma.reply(notification['status'], toot, cw=self.cfg['cw']) + await self.pleroma.reply(status, toot, cw=self.cfg['cw']) @staticmethod def cleanup_toot(text, cfg): diff --git a/third_party/utils.py b/third_party/utils.py index fa15435..247f5ad 100644 --- a/third_party/utils.py +++ b/third_party/utils.py @@ -63,13 +63,13 @@ def remove_mentions(cfg, sentence): return sentence -async def make_post(cfg, *, mode=TextGenerationMode.markov): +async def make_post(cfg, keywords=None, *, mode=TextGenerationMode.markov): if mode is TextGenerationMode.markov: from generators.markov import make_sentence elif mode is TextGenerationMode.gpt_2: from generators.gpt_2 import make_sentence - return await anyio.to_process.run_sync(make_sentence, cfg) + return await anyio.to_process.run_sync(make_sentence, cfg, keywords) def extract_post_content(text): soup = BeautifulSoup(text, "html.parser")