Rudimentary word-based reply generation

This commit is contained in:
Mint 2023-03-06 02:04:38 +03:00
parent 9f4f1d85dd
commit b5befe7393
5 changed files with 37 additions and 10 deletions

View file

@ -14,6 +14,7 @@
"overlap_ratio_enabled": false, "overlap_ratio_enabled": false,
"overlap_ratio": 0.7, "overlap_ratio": 0.7,
"generation_mode": "markov", "generation_mode": "markov",
"keywords_from_reply": true,
"access_token": "", "access_token": "",
"db_path": "posts.db" "db_path": "posts.db"
} }

View file

@ -2,9 +2,9 @@
import sqlite3 import sqlite3
import markovify import markovify
from random import randint from random import randint, choice
def make_sentence(cfg): def make_sentence(cfg, keywords):
class nlt_fixed(markovify.NewlineText): # modified version of NewlineText that never rejects sentences class nlt_fixed(markovify.NewlineText): # modified version of NewlineText that never rejects sentences
def test_sentence_input(self, sentence): def test_sentence_input(self, sentence):
return True # all sentences are valid <3 return True # all sentences are valid <3
@ -49,10 +49,25 @@ def make_sentence(cfg):
if cfg['limit_length']: if cfg['limit_length']:
sentence_len = randint(cfg['length_lower_limit'], cfg['length_upper_limit']) sentence_len = randint(cfg['length_lower_limit'], cfg['length_upper_limit'])
def make_short_sentence_with_keyword(max_chars, min_chars=0, keywords=None, **kwargs):
    """Generate a length-bounded sentence, preferring one that starts with a keyword.

    Each attempt picks a random word from ``keywords`` and asks the enclosing
    ``model`` for a sentence starting with it (``strict=False``). If that
    fails for any reason, or no keywords are available, the attempt falls
    back to an unseeded ``model.make_sentence``.

    Args:
        max_chars: Maximum accepted sentence length, inclusive.
        min_chars: Minimum accepted sentence length, inclusive.
        keywords: Text whose words are candidate sentence starts, or None.
        **kwargs: Forwarded unchanged to the model's generation methods.
            ``tries`` is also read here as the number of attempts.

    Returns:
        The first generated sentence whose length is within bounds, or None
        if every attempt failed.
    """
    # Default to 10 attempts when the caller omits `tries`; the previous
    # `kwargs.get("tries")` produced None and `range(None)` raised TypeError.
    # NOTE(review): 10 is believed to match markovify's own default -- confirm.
    tries = kwargs.get("tries", 10)

    # Splitting the keywords is loop-invariant, so do it once up front.
    try:
        candidates = model.word_split(keywords) if keywords else []
    except Exception:
        # Best effort: unusable keywords simply disable keyword seeding.
        candidates = []

    for _ in range(tries):
        sentence = None
        seeded = False
        if candidates:
            try:
                keyword = choice(candidates)
                sentence = model.make_sentence_with_start(keyword, strict=False, **kwargs)
                seeded = True
            except Exception:
                # Deliberate best-effort fallback (e.g. the keyword is not in
                # the model). `Exception` rather than a bare `except` so that
                # KeyboardInterrupt / SystemExit still propagate.
                seeded = False
        if not seeded:
            sentence = model.make_sentence(**kwargs)
        if sentence and min_chars <= len(sentence) <= max_chars:
            return sentence
    return None
sentence = None sentence = None
tries = 0 tries = 0
for tries in range(10): for tries in range(10):
if (sentence := model.make_short_sentence( if (sentence := make_short_sentence_with_keyword(
keywords=keywords if cfg['keywords_from_reply'] else None,
max_chars=500, max_chars=500,
tries=10000, tries=10000,
max_overlap_ratio=cfg['overlap_ratio'] if cfg['overlap_ratio_enabled'] else 0.7, max_overlap_ratio=cfg['overlap_ratio'] if cfg['overlap_ratio_enabled'] else 0.7,

View file

@ -66,7 +66,7 @@ class Pleroma:
id = self._unpack_id(id) id = self._unpack_id(id)
return await self.request('GET', f'/api/v1/statuses/{id}/context') return await self.request('GET', f'/api/v1/statuses/{id}/context')
async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None): async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None, keywords=None):
if visibility not in {None, 'private', 'public', 'unlisted', 'direct'}: if visibility not in {None, 'private', 'public', 'unlisted', 'direct'}:
raise ValueError('invalid visibility', visibility) raise ValueError('invalid visibility', visibility)

View file

@ -77,9 +77,14 @@ class ReplyBot:
if command in ('pin', 'unpin'): if command in ('pin', 'unpin'):
await (self.pleroma.pin if command == 'pin' else self.pleroma.unpin)(target_post_id) await (self.pleroma.pin if command == 'pin' else self.pleroma.unpin)(target_post_id)
elif command == 'reply': elif command == 'reply':
toot = await utils.make_post(self.cfg)
toot = self.cleanup_toot(toot, self.cfg)
status = await self.pleroma.get_status(argument) status = await self.pleroma.get_status(argument)
if status['content'] != "":
keywords = cleanup_toot(utils.extract_post_content(status['content']), self.cfg)
else:
keywords = None
toot = await utils.make_post(self.cfg, keywords)
toot = self.cleanup_toot(toot, self.cfg)
await self.pleroma.reply(status, toot, cw=self.cfg['cw']) await self.pleroma.reply(status, toot, cw=self.cfg['cw'])
except pleroma.BadRequest as exc: except pleroma.BadRequest as exc:
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
@ -89,9 +94,15 @@ class ReplyBot:
await self.pleroma.react(post_id, '') await self.pleroma.react(post_id, '')
async def reply(self, notification): async def reply(self, notification):
toot = await utils.make_post(self.cfg) status = notification['status']
if status['content'] != "":
keywords = self.cleanup_toot(utils.extract_post_content(status['content']), self.cfg)
else:
keywords = None
toot = await utils.make_post(self.cfg, keywords)
toot = self.cleanup_toot(toot, self.cfg) toot = self.cleanup_toot(toot, self.cfg)
await self.pleroma.reply(notification['status'], toot, cw=self.cfg['cw']) await self.pleroma.reply(status, toot, cw=self.cfg['cw'])
@staticmethod @staticmethod
def cleanup_toot(text, cfg): def cleanup_toot(text, cfg):

View file

@ -63,13 +63,13 @@ def remove_mentions(cfg, sentence):
return sentence return sentence
async def make_post(cfg, *, mode=TextGenerationMode.markov): async def make_post(cfg, keywords=None, *, mode=TextGenerationMode.markov):
if mode is TextGenerationMode.markov: if mode is TextGenerationMode.markov:
from generators.markov import make_sentence from generators.markov import make_sentence
elif mode is TextGenerationMode.gpt_2: elif mode is TextGenerationMode.gpt_2:
from generators.gpt_2 import make_sentence from generators.gpt_2 import make_sentence
return await anyio.to_process.run_sync(make_sentence, cfg) return await anyio.to_process.run_sync(make_sentence, cfg, keywords)
def extract_post_content(text): def extract_post_content(text):
soup = BeautifulSoup(text, "html.parser") soup = BeautifulSoup(text, "html.parser")