Implement new database methods in existing code

This commit is contained in:
Noi 2022-06-28 21:46:58 -07:00
parent 191ac13355
commit 478b39f268
4 changed files with 64 additions and 45 deletions

View file

@ -3,9 +3,9 @@ using NodaTime;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.Globalization; using System.Globalization;
using System.Text; using System.Text;
using WorldTime.Data;
namespace WorldTime; namespace WorldTime;
public class ApplicationCommands : InteractionModuleBase<ShardedInteractionContext> { public class ApplicationCommands : InteractionModuleBase<ShardedInteractionContext> {
const string ErrNotAllowed = ":x: Only server moderators may use this command."; const string ErrNotAllowed = ":x: Only server moderators may use this command.";
@ -30,7 +30,7 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
private static readonly ReadOnlyDictionary<string, string> _tzNameMap; private static readonly ReadOnlyDictionary<string, string> _tzNameMap;
public DiscordShardedClient ShardedClient { get; set; } = null!; public DiscordShardedClient ShardedClient { get; set; } = null!;
public Database Database { get; set; } = null!; public BotDatabaseContext DbContext { get; set; } = null!;
static ApplicationCommands() { static ApplicationCommands() {
Dictionary<string, string> tzNameMap = new(StringComparer.OrdinalIgnoreCase); Dictionary<string, string> tzNameMap = new(StringComparer.OrdinalIgnoreCase);
@ -42,7 +42,8 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
public async Task CmdHelp() { public async Task CmdHelp() {
var version = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version!.ToString(3); var version = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version!.ToString(3);
var guildct = ShardedClient.Guilds.Count; var guildct = ShardedClient.Guilds.Count;
var uniquetz = await Database.GetDistinctZoneCountAsync(); using var db = DbContext;
var uniquetz = db.GetDistinctZoneCount();
await RespondAsync(embed: new EmbedBuilder() { await RespondAsync(embed: new EmbedBuilder() {
Title = "Help & About", Title = "Help & About",
Description = $"World Time v{version} - Serving {guildct} communities across {uniquetz} time zones.\n\n" Description = $"World Time v{version} - Serving {guildct} communities across {uniquetz} time zones.\n\n"
@ -69,12 +70,19 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
return; return;
} }
if (user != null) { if (user == null) {
await CmdListUserAsync(user); // No parameter - full listing
return; await CmdListWithoutParamAsync();
} else {
// Has parameter - do single user listing
await CmdListWithUserParamAsync(user);
}
} }
var userlist = await Database.GetGuildZonesAsync(Context.Guild.Id); private async Task CmdListWithoutParamAsync() {
// Called by CmdList
using var db = DbContext;
var userlist = db.GetGuildZones(Context.Guild.Id);
if (userlist.Count == 0) { if (userlist.Count == 0) {
await RespondAsync(":x: Nothing to show. Register your time zones with the bot using the `/set` command."); await RespondAsync(":x: Nothing to show. Register your time zones with the bot using the `/set` command.");
return; return;
@ -138,9 +146,10 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
} }
} }
private async Task CmdListUserAsync(SocketGuildUser parameter) { private async Task CmdListWithUserParamAsync(SocketGuildUser parameter) {
// Not meant as a command handler - called by CmdList // Called by CmdList
var result = await Database.GetUserZoneAsync(parameter); using var db = DbContext;
var result = db.GetUserZone(parameter);
if (result == null) { if (result == null) {
bool isself = Context.User.Id == parameter.Id; bool isself = Context.User.Id == parameter.Id;
if (isself) await RespondAsync(":x: You do not have a time zone. Set it with `tz.set`.", ephemeral: true); if (isself) await RespondAsync(":x: You do not have a time zone. Set it with `tz.set`.", ephemeral: true);
@ -159,7 +168,8 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
await RespondAsync(ErrInvalidZone, ephemeral: true); await RespondAsync(ErrInvalidZone, ephemeral: true);
return; return;
} }
await Database.UpdateUserAsync((SocketGuildUser)Context.User, parsedzone); using var db = DbContext;
db.UpdateUser((SocketGuildUser)Context.User, parsedzone);
await RespondAsync($":white_check_mark: Your time zone has been set to **{parsedzone}**."); await RespondAsync($":white_check_mark: Your time zone has been set to **{parsedzone}**.");
} }
@ -179,14 +189,16 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
return; return;
} }
await Database.UpdateUserAsync(user, newtz).ConfigureAwait(false); using var db = DbContext;
db.UpdateUser(user, newtz);
await RespondAsync($":white_check_mark: Time zone for **{user}** set to **{newtz}**."); await RespondAsync($":white_check_mark: Time zone for **{user}** set to **{newtz}**.");
} }
[RequireGuildContext] [RequireGuildContext]
[SlashCommand("remove", HelpRemove)] [SlashCommand("remove", HelpRemove)]
public async Task CmdRemove() { public async Task CmdRemove() {
var success = await Database.DeleteUserAsync((SocketGuildUser)Context.User); using var db = DbContext;
var success = db.DeleteUser((SocketGuildUser)Context.User);
if (success) await RespondAsync(":white_check_mark: Your zone has been removed."); if (success) await RespondAsync(":white_check_mark: Your zone has been removed.");
else await RespondAsync(":x: You don't have a time zone set."); else await RespondAsync(":x: You don't have a time zone set.");
} }
@ -199,7 +211,8 @@ public class ApplicationCommands : InteractionModuleBase<ShardedInteractionConte
return; return;
} }
if (await Database.DeleteUserAsync(user)) using var db = DbContext;
if (db.DeleteUser(user))
await RespondAsync($":white_check_mark: Removed zone information for {user}."); await RespondAsync($":white_check_mark: Removed zone information for {user}.");
else else
await RespondAsync($":white_check_mark: No time zone is set for {user}."); await RespondAsync($":white_check_mark: No time zone is set for {user}.");

View file

@ -1,11 +1,13 @@
using NodaTime; using Microsoft.Extensions.DependencyInjection;
using NodaTime;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.Globalization; using System.Globalization;
using System.Text; using System.Text;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using WorldTime.Data;
namespace WorldTime; namespace WorldTime;
[Obsolete("Text commands are deprecated and will be removed soon.")]
internal class CommandsText { internal class CommandsText {
#if DEBUG #if DEBUG
public const string CommandPrefix = "tt."; public const string CommandPrefix = "tt.";
@ -15,9 +17,11 @@ internal class CommandsText {
delegate Task Command(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message); delegate Task Command(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message);
private readonly Dictionary<string, Command> _commands; private readonly Dictionary<string, Command> _commands;
private readonly Database _database; private readonly IServiceProvider _services;
private readonly WorldTime _instance; private readonly WorldTime _instance;
private BotDatabaseContext Database => _services.GetRequiredService<BotDatabaseContext>();
private static readonly Regex _userExplicit; private static readonly Regex _userExplicit;
private static readonly Regex _userMention; private static readonly Regex _userMention;
private static readonly ReadOnlyDictionary<string, string> _tzNameMap; private static readonly ReadOnlyDictionary<string, string> _tzNameMap;
@ -38,9 +42,9 @@ internal class CommandsText {
_tzNameMap = new(tzNameMap); _tzNameMap = new(tzNameMap);
} }
public CommandsText(WorldTime inst, Database db) { public CommandsText(WorldTime inst, IServiceProvider services) {
_instance = inst; _instance = inst;
_database = db; _services = services;
_commands = new(StringComparer.OrdinalIgnoreCase) { _commands = new(StringComparer.OrdinalIgnoreCase) {
{ "help", CmdHelp }, { "help", CmdHelp },
{ "list", CmdList }, { "list", CmdList },
@ -60,7 +64,7 @@ internal class CommandsText {
var msgsplit = message.Content.Split(' ', 2, StringSplitOptions.RemoveEmptyEntries); var msgsplit = message.Content.Split(' ', 2, StringSplitOptions.RemoveEmptyEntries);
if (msgsplit.Length == 0 || msgsplit[0].Length < 4) return; if (msgsplit.Length == 0 || msgsplit[0].Length < 4) return;
if (msgsplit[0].StartsWith(CommandPrefix, StringComparison.OrdinalIgnoreCase)) { // TODO add support for multiple prefixes? if (msgsplit[0].StartsWith(CommandPrefix, StringComparison.OrdinalIgnoreCase)) {
var cmdBase = msgsplit[0][3..]; var cmdBase = msgsplit[0][3..];
if (_commands.ContainsKey(cmdBase)) { if (_commands.ContainsKey(cmdBase)) {
Program.Log("Command invoked", $"{channel.Guild.Name}/{message.Author} {message.Content}"); Program.Log("Command invoked", $"{channel.Guild.Name}/{message.Author} {message.Content}");
@ -74,9 +78,10 @@ internal class CommandsText {
} }
private async Task CmdHelp(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message) { private async Task CmdHelp(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message) {
using var db = Database;
var version = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version!.ToString(3); var version = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version!.ToString(3);
var guildct = _instance.DiscordClient.Guilds.Count; var guildct = _instance.DiscordClient.Guilds.Count;
var uniquetz = await _database.GetDistinctZoneCountAsync(); var uniquetz = db.GetDistinctZoneCount();
await channel.SendMessageAsync(embed: new EmbedBuilder() { await channel.SendMessageAsync(embed: new EmbedBuilder() {
Color = new Color(0xe0f2f7), Color = new Color(0xe0f2f7),
Title = "Help & About", Title = "Help & About",
@ -118,7 +123,8 @@ internal class CommandsText {
return; return;
} }
var result = await _database.GetUserZoneAsync(usersearch).ConfigureAwait(false); using var db = Database;
var result = db.GetUserZone(usersearch);
if (result == null) { if (result == null) {
bool isself = sender.Id == usersearch.Id; bool isself = sender.Id == usersearch.Id;
if (isself) await channel.SendMessageAsync(":x: You do not have a time zone. Set it with `tz.set`.").ConfigureAwait(false); if (isself) await channel.SendMessageAsync(":x: You do not have a time zone. Set it with `tz.set`.").ConfigureAwait(false);
@ -130,7 +136,8 @@ internal class CommandsText {
await channel.SendMessageAsync(embed: new EmbedBuilder().WithDescription(resulttext).Build()).ConfigureAwait(false); await channel.SendMessageAsync(embed: new EmbedBuilder().WithDescription(resulttext).Build()).ConfigureAwait(false);
} else { } else {
// Does not have parameter - build full list // Does not have parameter - build full list
var userlist = await _database.GetGuildZonesAsync(channel.Guild.Id).ConfigureAwait(false); using var db = Database;
var userlist = db.GetGuildZones(channel.Guild.Id);
if (userlist.Count == 0) { if (userlist.Count == 0) {
await channel.SendMessageAsync(":x: Nothing to show. " + await channel.SendMessageAsync(":x: Nothing to show. " +
$"To register time zones with the bot, use the `{CommandPrefix}set` command.").ConfigureAwait(false); $"To register time zones with the bot, use the `{CommandPrefix}set` command.").ConfigureAwait(false);
@ -196,7 +203,8 @@ internal class CommandsText {
await channel.SendMessageAsync(ErrInvalidZone).ConfigureAwait(false); await channel.SendMessageAsync(ErrInvalidZone).ConfigureAwait(false);
return; return;
} }
await _database.UpdateUserAsync(sender, input).ConfigureAwait(false); using var db = Database;
db.UpdateUser(sender, input);
await channel.SendMessageAsync($":white_check_mark: Your time zone has been set to **{input}**.").ConfigureAwait(false); await channel.SendMessageAsync($":white_check_mark: Your time zone has been set to **{input}**.").ConfigureAwait(false);
} }
@ -229,12 +237,14 @@ internal class CommandsText {
return; return;
} }
await _database.UpdateUserAsync(targetuser, newtz).ConfigureAwait(false); using var db = Database;
db.UpdateUser(targetuser, newtz);
await channel.SendMessageAsync($":white_check_mark: Time zone for **{targetuser}** set to **{newtz}**.").ConfigureAwait(false); await channel.SendMessageAsync($":white_check_mark: Time zone for **{targetuser}** set to **{newtz}**.").ConfigureAwait(false);
} }
private async Task CmdRemove(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message) { private async Task CmdRemove(SocketTextChannel channel, SocketGuildUser sender, SocketMessage message) {
var success = await _database.DeleteUserAsync(sender).ConfigureAwait(false); using var db = Database;
var success = db.DeleteUser(sender);
if (success) await channel.SendMessageAsync(":white_check_mark: Your zone has been removed.").ConfigureAwait(false); if (success) await channel.SendMessageAsync(":white_check_mark: Your zone has been removed.").ConfigureAwait(false);
else await channel.SendMessageAsync(":x: You don't have a time zone set."); else await channel.SendMessageAsync(":x: You don't have a time zone set.");
} }
@ -259,7 +269,8 @@ internal class CommandsText {
return; return;
} }
await _database.DeleteUserAsync(targetuser).ConfigureAwait(false); using var db = Database;
db.DeleteUser(targetuser);
await channel.SendMessageAsync($":white_check_mark: Removed zone information for {targetuser}."); await channel.SendMessageAsync($":white_check_mark: Removed zone information for {targetuser}.");
} }
@ -346,7 +357,6 @@ internal class CommandsText {
/// </summary> /// </summary>
private static bool IsUserAdmin(SocketGuildUser user) private static bool IsUserAdmin(SocketGuildUser user)
=> user.GuildPermissions.Administrator || user.GuildPermissions.ManageGuild; => user.GuildPermissions.Administrator || user.GuildPermissions.ManageGuild;
// TODO port modrole feature from BB, implement in here
/// <summary> /// <summary>
/// Checks if the member cache for the specified guild needs to be filled, and sends a request if needed. /// Checks if the member cache for the specified guild needs to be filled, and sends a request if needed.

View file

@ -18,16 +18,10 @@ class Program {
Environment.Exit((int)ExitCodes.ConfigError); Environment.Exit((int)ExitCodes.ConfigError);
} }
Database? d = null; Data.BotDatabaseContext.NpgsqlConnectionString = cfg.DbConnectionString;
try {
d = new(cfg.DbConnectionString);
} catch (Npgsql.NpgsqlException e) {
Console.WriteLine("Error when attempting to connect to database: " + e.Message);
Environment.Exit((int)ExitCodes.DatabaseError);
}
Console.CancelKeyPress += OnCancelKeyPressed; Console.CancelKeyPress += OnCancelKeyPressed;
_bot = new WorldTime(cfg, d); _bot = new WorldTime(cfg);
await _bot.StartAsync().ConfigureAwait(false); await _bot.StartAsync().ConfigureAwait(false);
await Task.Delay(-1).ConfigureAwait(false); await Task.Delay(-1).ConfigureAwait(false);

View file

@ -4,6 +4,7 @@ using Discord.Interactions;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
using WorldTime.Data;
namespace WorldTime; namespace WorldTime;
@ -33,10 +34,9 @@ internal class WorldTime : IDisposable {
internal Configuration Config { get; } internal Configuration Config { get; }
internal DiscordShardedClient DiscordClient => _services.GetRequiredService<DiscordShardedClient>(); internal DiscordShardedClient DiscordClient => _services.GetRequiredService<DiscordShardedClient>();
internal Database Database => _services.GetRequiredService<Database>();
public WorldTime(Configuration cfg, Database d) { public WorldTime(Configuration cfg) {
var ver = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version; var ver = Assembly.GetExecutingAssembly().GetName().Version;
Program.Log(nameof(WorldTime), $"Version {ver!.ToString(3)} is starting..."); Program.Log(nameof(WorldTime), $"Version {ver!.ToString(3)} is starting...");
Config = cfg; Config = cfg;
@ -51,7 +51,7 @@ internal class WorldTime : IDisposable {
_services = new ServiceCollection() _services = new ServiceCollection()
.AddSingleton(new DiscordShardedClient(clientConf)) .AddSingleton(new DiscordShardedClient(clientConf))
.AddSingleton(s => new InteractionService(s.GetRequiredService<DiscordShardedClient>())) .AddSingleton(s => new InteractionService(s.GetRequiredService<DiscordShardedClient>()))
.AddSingleton(d) .AddTransient(typeof(BotDatabaseContext))
.BuildServiceProvider(); .BuildServiceProvider();
DiscordClient.Log += DiscordClient_Log; DiscordClient.Log += DiscordClient_Log;
DiscordClient.ShardReady += DiscordClient_ShardReady; DiscordClient.ShardReady += DiscordClient_ShardReady;
@ -60,7 +60,7 @@ internal class WorldTime : IDisposable {
DiscordClient.InteractionCreated += DiscordClient_InteractionCreated; DiscordClient.InteractionCreated += DiscordClient_InteractionCreated;
iasrv.SlashCommandExecuted += InteractionService_SlashCommandExecuted; iasrv.SlashCommandExecuted += InteractionService_SlashCommandExecuted;
_commandsTxt = new CommandsText(this, Database); _commandsTxt = new CommandsText(this, _services);
// Start status reporting thread // Start status reporting thread
_mainCancel = new CancellationTokenSource(); _mainCancel = new CancellationTokenSource();
@ -189,10 +189,12 @@ internal class WorldTime : IDisposable {
// Proactively fill guild user cache if the bot has any data for the respective guild // Proactively fill guild user cache if the bot has any data for the respective guild
// Can skip an extra query if the last_seen update is known to have been successful, otherwise query for any users // Can skip an extra query if the last_seen update is known to have been successful, otherwise query for any users
var guild = channel.Guild; if (!channel.Guild.HasAllMembers) {
if (!guild.HasAllMembers && await Database.HasAnyAsync(guild)) { using var db = _services.GetRequiredService<BotDatabaseContext>();
if (db.HasAnyUsers(channel.Guild)) {
// Event handler hangs if awaited normally or used with Task.Run // Event handler hangs if awaited normally or used with Task.Run
await Task.Factory.StartNew(guild.DownloadUsersAsync); await Task.Factory.StartNew(channel.Guild.DownloadUsersAsync);
}
} }
} }