From f0c84e809fcca6216dc2858a9b8fad0996cc89fd Mon Sep 17 00:00:00 2001 From: Teriuihi Date: Sat, 10 Aug 2024 00:50:53 +0200 Subject: [PATCH] Add rate limiting functionality Introduces a new 'rate_limit' table to track request counts by IP and email. Adds `RateLimitQuery` class for querying and inserting rate limits, and `RateLimitEntryDTO` for passing rate limit data. --- .../com/alttd/forms/database/Database.java | 31 +++++++++- .../mail/rate_limitter/RateLimitEntryDTO.java | 6 ++ .../mail/rate_limitter/RateLimitQuery.java | 59 +++++++++++++++++++ 3 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitEntryDTO.java create mode 100644 src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitQuery.java diff --git a/src/main/java/com/alttd/forms/database/Database.java b/src/main/java/com/alttd/forms/database/Database.java index 5516d82..06b610b 100644 --- a/src/main/java/com/alttd/forms/database/Database.java +++ b/src/main/java/com/alttd/forms/database/Database.java @@ -12,8 +12,35 @@ public class Database { public static void createTables() { String[] createTables = { - "CREATE TABLE IF NOT EXISTS verify_form (e_mail VARCHAR(256), verification_code INT, formId INT, PRIMARY KEY(e_mail, verification_code))", - "CREATE TABLE IF NOT EXISTS form (formId INT AUTO_INCREMENT, creation_date BIGINT, form_json TEXT, form_class VARCHAR(64), PRIMARY KEY(formId))" + // language=SQL + """ + CREATE TABLE IF NOT EXISTS verify_form( + e_mail VARCHAR(256), + verification_code INT, + formId INT, + PRIMARY KEY(e_mail, verification_code) + ) + """, + // language=SQL + """ + CREATE TABLE IF NOT EXISTS form( + formId INT AUTO_INCREMENT, + creation_date BIGINT, + form_json TEXT, + form_class VARCHAR(64), + PRIMARY KEY(formId) + ) + """, + // language=SQL + """ + CREATE TABLE IF NOT EXISTS rate_limit( + id INT AUTO_INCREMENT, + time TIMESTAMP, + ip VARCHAR(45), + mail VARCHAR(256), + PRIMARY KEY(id) + ) + """ }; Connection connection = DatabaseConnection.getConnection(); for (String query : createTables) { diff --git a/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitEntryDTO.java b/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitEntryDTO.java new file mode 100644 index 0000000..b694e30 --- /dev/null +++ b/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitEntryDTO.java @@ -0,0 +1,6 @@ +package com.alttd.forms.mail.rate_limitter; + +import java.time.Instant; + +public record RateLimitEntryDTO(Instant time, String ip, String mail) { +} diff --git a/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitQuery.java b/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitQuery.java new file mode 100644 index 0000000..af3baf5 --- /dev/null +++ b/src/main/java/com/alttd/forms/mail/rate_limitter/RateLimitQuery.java @@ -0,0 +1,59 @@ +package com.alttd.forms.mail.rate_limitter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.*; +import java.time.Instant; + +public class RateLimitQuery { + + private static final Logger logger = LoggerFactory.getLogger(RateLimitQuery.class); + + public int getIpHits(Connection connection, String ip, Instant after) throws SQLException { + String sql = "SELECT COUNT(*) AS hits FROM rate_limit WHERE ip = ? AND time > ?"; + try (PreparedStatement stmt = connection.prepareStatement(sql)) { + stmt.setString(1, ip); + stmt.setTimestamp(2, Timestamp.from(after)); + ResultSet resultSet = stmt.executeQuery(); + if (!resultSet.next()) { + return 0; + } + return resultSet.getInt("hits"); + } catch (SQLException e) { + logger.error("Failed get ip hits query for ip: {}", ip, e); + throw e; + } + } + + public int getMailHits(Connection connection, String mail, Instant after) throws SQLException { + String sql = "SELECT COUNT(*) AS hits FROM rate_limit WHERE mail = ? AND time > ?"; + try (PreparedStatement stmt = connection.prepareStatement(sql)) { + stmt.setString(1, mail); + stmt.setTimestamp(2, Timestamp.from(after)); + ResultSet resultSet = stmt.executeQuery(); + if (!resultSet.next()) { + return 0; + } + return resultSet.getInt("hits"); + } catch (SQLException e) { + logger.error("Failed get mail hits query for ip: {}", mail, e); + throw e; + } + } + + public boolean insertRateLimitEntry(Connection connection, RateLimitEntryDTO entry) throws SQLException { + String sql = "INSERT INTO rate_limit (time, ip, mail) VALUES (?, ?, ?)"; + try { + PreparedStatement stmt = connection.prepareStatement(sql); + stmt.setTimestamp(1, Timestamp.from(entry.time())); + stmt.setString(2, entry.ip()); + stmt.setString(3, entry.mail()); + return stmt.executeUpdate() > 0; + } catch (SQLException e) { + logger.error("Failed to store rate limit for ip: {}, mail: {}, time: {}", entry.ip(), entry.mail(), entry.time(), e); + throw e; + } + } + +}