added support for learning from CW'd posts

This commit is contained in:
Lynne 2019-02-25 11:17:06 +10:00
parent 5dfd167bc0
commit 2cecfd42a5
No known key found for this signature in database
GPG key ID: FB7B970303ACE499
3 changed files with 19 additions and 9 deletions

View file

@ -1,4 +1,5 @@
{ {
"site": "https://botsin.space", "site": "https://botsin.space",
"cw": null "cw": null,
"learn_from_cw": false
} }

View file

@ -7,6 +7,8 @@ import markovify
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import re, multiprocessing, sqlite3, shutil, os, json import re, multiprocessing, sqlite3, shutil, os, json
cfg = json.load(open('config.json'))
def make_sentence(output): def make_sentence(output):
class nlt_fixed(markovify.NewlineText): #modified version of NewlineText that never rejects sentences class nlt_fixed(markovify.NewlineText): #modified version of NewlineText that never rejects sentences
def test_sentence_input(self, sentence): def test_sentence_input(self, sentence):
@ -16,7 +18,10 @@ def make_sentence(output):
db = sqlite3.connect("toots-copy.db") db = sqlite3.connect("toots-copy.db")
db.text_factory=str db.text_factory=str
c = db.cursor() c = db.cursor()
toots = c.execute("SELECT content FROM `toots` ORDER BY RANDOM() LIMIT 10000").fetchall() if cfg['learn_from_cw']:
toots = c.execute("SELECT content FROM `toots` ORDER BY RANDOM() LIMIT 10000").fetchall()
else:
toots = c.execute("SELECT content FROM `toots` WHERE cw = 0 ORDER BY RANDOM() LIMIT 10000").fetchall()
toots_str = "" toots_str = ""
for toot in toots: for toot in toots:
toots_str += "\n{}".format(toot[0]) toots_str += "\n{}".format(toot[0])
@ -32,7 +37,6 @@ def make_sentence(output):
tries = tries + 1 tries = tries + 1
sentence = re.sub("^(?:@\u202B[^ ]* )*", "", sentence) #remove leading pings (don't say "@bob blah blah" but still say "blah @bob blah") sentence = re.sub("^(?:@\u202B[^ ]* )*", "", sentence) #remove leading pings (don't say "@bob blah blah" but still say "blah @bob blah")
sentence = re.sub("^(?:@\u200B[^ ]* )*", "", sentence)
output.send(sentence) output.send(sentence)

17
main.py
View file

@ -18,7 +18,8 @@ except:
shutil.copy2("config.sample.json", "config.json") shutil.copy2("config.sample.json", "config.json")
cfg = json.load(open('config.json', 'r')) cfg = json.load(open('config.json', 'r'))
#config.json *MUST* contain the instance URL, the instance blacklist (for dead/broken instances), and the CW text. if they're not provided, we'll fall back to defaults. #config.json should contain the instance URL, the instance blacklist (for dead/broken instances), and the CW text. if they're not provided, we'll fall back to defaults.
# TODO: this is pretty messy
if 'site' not in cfg: if 'site' not in cfg:
cfg['website'] = "https://botsin.space" cfg['website'] = "https://botsin.space"
if 'cw' not in cfg: if 'cw' not in cfg:
@ -28,6 +29,8 @@ if 'instance_blacklist' not in cfg:
"bofa.lol", "bofa.lol",
"witches.town" "witches.town"
] ]
if 'learn_from_cw' not in cfg:
cfg['learn_from_cw'] = False
#if the user is using a (very!) old version that still uses the .secret files, migrate to the new method #if the user is using a (very!) old version that still uses the .secret files, migrate to the new method
if os.path.exists("clientcred.secret"): if os.path.exists("clientcred.secret"):
@ -82,7 +85,11 @@ following = client.account_following(me.id)
db = sqlite3.connect("toots.db") db = sqlite3.connect("toots.db")
db.text_factory=str db.text_factory=str
c = db.cursor() c = db.cursor()
c.execute("CREATE TABLE IF NOT EXISTS `toots` (id INT NOT NULL UNIQUE PRIMARY KEY, userid INT NOT NULL, uri VARCHAR NOT NULL, content VARCHAR NOT NULL) WITHOUT ROWID") c.execute("CREATE TABLE IF NOT EXISTS `toots` (id INT NOT NULL UNIQUE PRIMARY KEY, cw INT NOT NULL DEFAULT 0, userid INT NOT NULL, uri VARCHAR NOT NULL, content VARCHAR NOT NULL) WITHOUT ROWID")
try:
c.execute("ALTER TABLE `toots` ADD COLUMN cw INT NOT NULL DEFAULT 0")
except:
pass # column already exists
db.commit() db.commit()
def handleCtrlC(signal, frame): def handleCtrlC(signal, frame):
@ -157,9 +164,6 @@ for f in following:
# its a toost baby # its a toost baby
content = oi['object']['content'] content = oi['object']['content']
if oi['object']['summary'] != None and oi['object']['summary'] != "":
#don't download CW'd toots. if you want your bot to download and learn from CW'd toots, replace "continue" with "pass". (todo: add a config.json option for this)
continue
toot = extract_toot(content) toot = extract_toot(content)
# print(toot) # print(toot)
try: try:
@ -170,8 +174,9 @@ for f in following:
done = True done = True
break break
pid = patterns["pid"].search(oi['object']['id']).group(0) pid = patterns["pid"].search(oi['object']['id']).group(0)
c.execute("REPLACE INTO toots (id, userid, uri, content) VALUES (?, ?, ?, ?)", ( c.execute("REPLACE INTO toots (id, cw, userid, uri, content) VALUES (?, ?, ?, ?)", (
pid, pid,
1 if (oi['object']['summary'] != None and oi['object']['summary'] != "") else 0,
f.id, f.id,
oi['object']['id'], oi['object']['id'],
toot toot