diff --git a/src/main/java/wtf/beatrice/hidekobot/Configuration.java b/src/main/java/wtf/beatrice/hidekobot/Configuration.java index 13c8295..6087b9c 100644 --- a/src/main/java/wtf/beatrice/hidekobot/Configuration.java +++ b/src/main/java/wtf/beatrice/hidekobot/Configuration.java @@ -6,9 +6,7 @@ import wtf.beatrice.hidekobot.database.DatabaseManager; import wtf.beatrice.hidekobot.listeners.MessageLogger; import java.awt.*; -import java.time.Duration; import java.time.LocalDateTime; -import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -27,7 +25,7 @@ public class Configuration private final static String expiryTimestampFormat = "yy/MM/dd HH:mm:ss"; // note: discord sets interactions' expiry time to 15 minutes by default, so we can't go higher than that. - private final static long expiryTimeSeconds = 60L; + private final static long expiryTimeSeconds = 5L; // used to count eg. uptime private static LocalDateTime startupTime; diff --git a/src/main/java/wtf/beatrice/hidekobot/commands/slash/ClearChatCommand.java b/src/main/java/wtf/beatrice/hidekobot/commands/slash/ClearChatCommand.java index ff92bf5..7751ce8 100644 --- a/src/main/java/wtf/beatrice/hidekobot/commands/slash/ClearChatCommand.java +++ b/src/main/java/wtf/beatrice/hidekobot/commands/slash/ClearChatCommand.java @@ -165,13 +165,8 @@ public class ClearChatCommand .setActionRow(dismissButton) .complete(); - String replyMessageId = message.getId(); - String replyChannelId = message.getChannel().getId(); - String replyGuildId = message.getGuild().getId(); - String userId = event.getUser().getId(); - - Configuration.getDatabaseManager().queueDisabling(replyGuildId, replyChannelId, replyMessageId); - Configuration.getDatabaseManager().trackRanCommandReply(replyGuildId, replyChannelId, replyMessageId, userId); + Configuration.getDatabaseManager().queueDisabling(message); + Configuration.getDatabaseManager().trackRanCommandReply(message, event.getUser()); } }).start(); diff --git a/src/main/java/wtf/beatrice/hidekobot/commands/slash/CoinFlipCommand.java b/src/main/java/wtf/beatrice/hidekobot/commands/slash/CoinFlipCommand.java index 03b0ba5..8776c42 100644 --- a/src/main/java/wtf/beatrice/hidekobot/commands/slash/CoinFlipCommand.java +++ b/src/main/java/wtf/beatrice/hidekobot/commands/slash/CoinFlipCommand.java @@ -62,12 +62,9 @@ public class CoinFlipCommand private void trackAndRestrict(Message replyMessage, User user) { String replyMessageId = replyMessage.getId(); - String replyChannelId = replyMessage.getChannel().getId(); - String replyGuildId = replyMessage.getGuild().getId(); - String userId = user.getId(); - Configuration.getDatabaseManager().queueDisabling(replyGuildId, replyChannelId, replyMessageId); - Configuration.getDatabaseManager().trackRanCommandReply(replyGuildId, replyChannelId, replyMessageId, userId); + Configuration.getDatabaseManager().queueDisabling(replyMessage); + Configuration.getDatabaseManager().trackRanCommandReply(replyMessage, user); } private String genRandom() diff --git a/src/main/java/wtf/beatrice/hidekobot/database/DatabaseManager.java b/src/main/java/wtf/beatrice/hidekobot/database/DatabaseManager.java index 25313af..8da750b 100644 --- a/src/main/java/wtf/beatrice/hidekobot/database/DatabaseManager.java +++ b/src/main/java/wtf/beatrice/hidekobot/database/DatabaseManager.java @@ -1,5 +1,8 @@ package wtf.beatrice.hidekobot.database; +import net.dv8tion.jda.api.entities.Message; +import net.dv8tion.jda.api.entities.User; +import net.dv8tion.jda.api.entities.channel.ChannelType; import wtf.beatrice.hidekobot.Configuration; import wtf.beatrice.hidekobot.utils.Logger; @@ -61,7 +64,24 @@ public class DatabaseManager return true; } - + /* + * DB STRUCTURE + * TABLE 1: pending_disabled_messages + * ---------------------------------------------------------------------------------- + * | guild_id | channel_id | message_id | expiry_timestamp | + * ---------------------------------------------------------------------------------- + * |39402849302 | 39402849302 | 39402849302 | 2022-11-20 22:45:53:300 | + * --------------------------------------------------------------------------------- + * + * + * TABLE 2: command_runners + * -------------------------------------------------------------------------------------------- + * | guild_id | channel_id | message_id | user_id | channel_type | + * -------------------------------------------------------------------------------------------- + * | 39402849302 | 39402849302 | 39402849302 | 39402849302 | PRIVATE | + * -------------------------------------------------------------------------------------------- + * + */ public boolean initDb() { @@ -71,14 +91,15 @@ public class DatabaseManager "guild_id TEXT NOT NULL, " + "channel_id TEXT NOT NULL," + "message_id TEXT NOT NULL," + - "expiry_timestamp TEXT NOT NULL" + + "expiry_timestamp TEXT NOT NULL " + ");"); newTables.add("CREATE TABLE IF NOT EXISTS command_runners (" + "guild_id TEXT NOT NULL, " + "channel_id TEXT NOT NULL," + // channel the command was run in "message_id TEXT NOT NULL," + // message id of the bot's response - "user_id TEXT NOT NULL" + // user who ran the command + "user_id TEXT NOT NULL, " + // user who ran the command + "channel_type TEXT NOT NULL" + // channel type (PRIVATE, FORUM, ...) ");"); for(String sql : newTables) @@ -95,11 +116,26 @@ public class DatabaseManager return true; } - public boolean trackRanCommandReply(String guildId, String channelId, String messageId, String userId) + public boolean trackRanCommandReply(Message message, User user) { + String userId = user.getId(); + String guildId; + + ChannelType channelType = message.getChannelType(); + if(channelType == ChannelType.PRIVATE) + { + guildId = userId; + } else { + guildId = message.getGuild().getId(); + } + + String channelId = message.getChannel().getId(); + String messageId = message.getId(); + + String query = "INSERT INTO command_runners " + - "(guild_id, channel_id, message_id, user_id) VALUES " + - " (?, ?, ?, ?);"; + "(guild_id, channel_id, message_id, user_id, channel_type) VALUES " + + " (?, ?, ?, ?, ?);"; try(PreparedStatement preparedStatement = dbConnection.prepareStatement(query)) { @@ -107,6 +143,7 @@ public class DatabaseManager preparedStatement.setString(2, channelId); preparedStatement.setString(3, messageId); preparedStatement.setString(4, userId); + preparedStatement.setString(5, channelType.name()); preparedStatement.executeUpdate(); @@ -126,6 +163,31 @@ public class DatabaseManager return userId.equals(trackedUserId); } + public ChannelType getTrackedMessageChannelType(String messageId) + { + String query = "SELECT channel_type " + + "FROM command_runners " + + "WHERE message_id = ?;"; + + try(PreparedStatement preparedStatement = dbConnection.prepareStatement(query)) + { + preparedStatement.setString(1, messageId); + ResultSet resultSet = preparedStatement.executeQuery(); + if(resultSet.isClosed()) return null; + while(resultSet.next()) + { + String channelTypeName = resultSet.getString("channel_type"); + return ChannelType.valueOf(channelTypeName); + } + + } catch (SQLException e) { + e.printStackTrace(); + } + + return null; + + } + public String getTrackedReplyUserId(String messageId) { String query = "SELECT user_id " + @@ -149,8 +211,20 @@ public class DatabaseManager return null; } - public boolean queueDisabling(String guildId, String channelId, String messageId) + public boolean queueDisabling(Message message) { + String messageId = message.getId(); + String channelId = message.getChannel().getId(); + String guildId; + + ChannelType channelType = message.getChannelType(); + if(channelType == ChannelType.PRIVATE) + { + guildId = "PRIVATE"; + } else { + guildId = message.getGuild().getId(); + } + LocalDateTime expiryTime = LocalDateTime.now().plusSeconds(Configuration.getExpiryTimeSeconds()); DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(Configuration.getExpiryTimestampFormat()); @@ -297,24 +371,5 @@ public class DatabaseManager return null; } - /** - * DB STRUCTURE - * TABLE 1: pending_disabled_messages - * ---------------------------------------------------------------------------------- - * | guild_id | channel_id | message_id | expiry_timestamp | - * ---------------------------------------------------------------------------------- - * |39402849302 | 39402849302 | 39402849302 | 2022-11-20 22:45:53:300 | - * ---------------------------------------------------------------------------------- - * - * - * TABLE 2: command_runners - * -------------------------------------------------------------------------- - * | guild_id | channel_id | message_id | user_id | - * -------------------------------------------------------------------------- - * | 39402849302 | 39402849302 | 39402849302 | 39402849302 | - * -------------------------------------------------------------------------- - * - */ - } diff --git a/src/main/java/wtf/beatrice/hidekobot/runnables/ExpiredMessageTask.java b/src/main/java/wtf/beatrice/hidekobot/runnables/ExpiredMessageTask.java index 9d0cb79..f7c2fd7 100644 --- a/src/main/java/wtf/beatrice/hidekobot/runnables/ExpiredMessageTask.java +++ b/src/main/java/wtf/beatrice/hidekobot/runnables/ExpiredMessageTask.java @@ -2,7 +2,9 @@ package wtf.beatrice.hidekobot.runnables; import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.entities.Message; -import net.dv8tion.jda.api.entities.channel.concrete.TextChannel; +import net.dv8tion.jda.api.entities.User; +import net.dv8tion.jda.api.entities.channel.ChannelType; +import net.dv8tion.jda.api.entities.channel.middleman.MessageChannel; import net.dv8tion.jda.api.interactions.components.LayoutComponent; import net.dv8tion.jda.api.requests.RestAction; import wtf.beatrice.hidekobot.Configuration; @@ -71,20 +73,49 @@ public class ExpiredMessageTask implements Runnable { private void disableExpired(String messageId) { - String guildId = databaseManager.getQueuedExpiringMessageGuild(messageId); String channelId = databaseManager.getQueuedExpiringMessageChannel(messageId); + ChannelType msgChannelType = databaseManager.getTrackedMessageChannelType(messageId); + + MessageChannel textChannel = null; - Guild guild = HidekoBot.getAPI().getGuildById(guildId); - if(guild == null) + // this should never happen, but only message channels are supported. + if(!msgChannelType.isMessage()) { - // if guild is not found, consider it expired - // (server was deleted or bot was kicked) databaseManager.untrackExpiredMessage(messageId); return; } - TextChannel textChannel = guild.getTextChannelById(channelId); + + // if this is a DM + if(msgChannelType == ChannelType.PRIVATE) + { + String userId = databaseManager.getTrackedReplyUserId(messageId); + User user = HidekoBot.getAPI().retrieveUserById(userId).complete(); + if(user == null) + { + // if user is not found, consider it expired + // (deleted profile, or blocked the bot) + databaseManager.untrackExpiredMessage(messageId); + return; + } + + textChannel = user.openPrivateChannel().complete(); + } + else + { + String guildId = databaseManager.getQueuedExpiringMessageGuild(messageId); + Guild guild = HidekoBot.getAPI().getGuildById(guildId); + if(guild == null) + { + // if guild is not found, consider it expired + // (server was deleted or bot was kicked) + databaseManager.untrackExpiredMessage(messageId); + return; + } + textChannel = guild.getTextChannelById(channelId); + } + if(textChannel == null) { // if channel is not found, count it as expired