WorldTime/userdb.py
2018-11-01 21:09:22 -07:00

129 lines
No EOL
4.2 KiB
Python

# User database abstractions
import psycopg2
class UserDatabase:
def __init__(self, connstr):
'''
Sets up the PostgreSQL connection to be used by this instance.
'''
self.db = psycopg2.connect(connstr)
cur = self.db.cursor()
cur.execute("""
CREATE TABLE IF NOT EXISTS userdata (
guild_id BIGINT,
user_id BIGINT,
zone TEXT NOT NULL,
last_active TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (guild_id, user_id)
)""")
self.db.commit()
cur.close()
def update_activity(self, serverid : str, authorid : str):
'''
If a user exists in the database, updates their last activity timestamp.
'''
c = self.db.cursor()
c.execute("""
UPDATE userdata SET last_active = now()
WHERE guild_id = %s AND user_id = %s
""", (serverid, authorid))
self.db.commit()
c.close()
def delete_user(self, serverid : str, authorid : str):
'''
Deletes existing user from the database.
'''
c = self.db.cursor()
c.execute("""
DELETE FROM userdata
WHERE guild_id = %s AND user_id = %s
""", (serverid, authorid))
self.db.commit()
c.close()
def update_user(self, serverid : str, authorid : str, zone : str):
'''
Insert or update user in the database.
Does not do any sanitizing of incoming values, as only a small set of
values are allowed anyway. This is enforced by the caller.
'''
self.delete_user(serverid, authorid)
c = self.db.cursor()
c.execute("""
INSERT INTO userdata (guild_id, user_id, zone) VALUES
(%s, %s, %s)
ON CONFLICT (guild_id, user_id)
DO UPDATE SET zone = EXCLUDED.zone
""", (serverid, authorid, zone))
self.db.commit()
c.close()
def get_list(self, serverid, userid=None):
'''
Retrieves a list of recent time zones based on
recent activity per user. For use in the list command.
'''
c = self.db.cursor()
if userid is None:
c.execute("""
SELECT zone, count(*) as ct FROM userdata
WHERE guild_id = %s
AND last_active >= now() - INTERVAL '72 HOURS' -- only users active in the last 72 hrs
GROUP BY zone -- separate by popularity
ORDER BY ct DESC LIMIT 10 -- top 10 zones are given
""", (serverid))
else:
c.execute("""
SELECT zone, '0' as ct FROM userdata
WHERE guild_id = %s AND user_id = %s
""", (serverid, userid))
results = c.fetchall()
c.close()
return [i[0] for i in results]
def get_list2(self, serverid):
'''
Retrieves data for the tz.list command.
Returns a dictionary. Keys are zone name, values are arrays with user IDs.
'''
c = self.db.cursor()
c.execute("""
SELECT zone, user_id
FROM userdata
WHERE
last_active >= now() - INTERVAL '72 HOURS' -- only users active in the last 72 hrs
AND guild_id = %(guild)s
AND zone in (SELECT zone from (
SELECT zone, count(*) as ct
FROM userdata
WHERE
guild_id = %(guild)s
AND last_active >= now() - INTERVAL '72 HOURS'
GROUP BY zone
LIMIT 10
) as pop_zones)
ORDER BY RANDOM() -- Randomize display order (expected by consumer)
""", {'guild': serverid})
result = {}
for row in c:
inlist = result.get(row[0])
if inlist is None:
result[row[0]] = []
inlist = result[row[0]]
inlist.append(row[1])
c.close()
return result
def get_unique_tz_count(self):
'''
Gets the number of unique time zones in the database.
'''
c = self.db.cursor()
c.execute('SELECT COUNT(DISTINCT zone) FROM userdata')
result = c.fetchall()
c.close()
return result[0][0]