diff --git a/backend/src/main/java/com/alttd/altitudeweb/controllers/data_from_auth/AuthenticatedUuid.java b/backend/src/main/java/com/alttd/altitudeweb/controllers/data_from_auth/AuthenticatedUuid.java index c5ea419..c8ac6c6 100644 --- a/backend/src/main/java/com/alttd/altitudeweb/controllers/data_from_auth/AuthenticatedUuid.java +++ b/backend/src/main/java/com/alttd/altitudeweb/controllers/data_from_auth/AuthenticatedUuid.java @@ -10,6 +10,7 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.stereotype.Service; import org.springframework.web.server.ResponseStatusException; +import java.util.Optional; import java.util.UUID; @Slf4j @@ -45,4 +46,28 @@ public class AuthenticatedUuid { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid UUID format"); } } + + /** + * Extracts the authenticated user's UUID from the JWT token. + * + * @return The UUID of the authenticated user + */ + public Optional tryGetAuthenticatedUserUuid() { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + + if (authentication == null || !(authentication.getPrincipal() instanceof Jwt jwt)) { + if (unsecured) { + return Optional.of(UUID.fromString("55e46bc3-2a29-4c53-850f-dbd944dc5c5f")); + } + return Optional.empty(); + } + + String stringUuid = jwt.getSubject(); + + try { + return Optional.of(UUID.fromString(stringUuid)); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } + } } diff --git a/backend/src/main/java/com/alttd/altitudeweb/services/limits/RateLimitAspect.java b/backend/src/main/java/com/alttd/altitudeweb/services/limits/RateLimitAspect.java index ba0d939..9bb6681 100644 --- a/backend/src/main/java/com/alttd/altitudeweb/services/limits/RateLimitAspect.java +++ b/backend/src/main/java/com/alttd/altitudeweb/services/limits/RateLimitAspect.java @@ -1,5 +1,6 @@ package com.alttd.altitudeweb.services.limits; +import com.alttd.altitudeweb.controllers.data_from_auth.AuthenticatedUuid; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import lombok.RequiredArgsConstructor; @@ -16,6 +17,8 @@ import org.springframework.web.context.request.ServletRequestAttributes; import java.lang.reflect.Method; import java.time.Duration; +import java.util.Optional; +import java.util.UUID; @Aspect @Component @@ -24,6 +27,7 @@ import java.time.Duration; public class RateLimitAspect { private final InMemoryRateLimiterService rateLimiterService; + private final AuthenticatedUuid authenticatedUuid; @Around(""" @annotation(com.alttd.altitudeweb.services.limits.RateLimit) @@ -37,7 +41,6 @@ public class RateLimitAspect { HttpServletRequest request = requestAttributes.getRequest(); HttpServletResponse response = requestAttributes.getResponse(); - String clientIp = request.getRemoteAddr(); MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Method method = signature.getMethod(); @@ -54,7 +57,12 @@ public class RateLimitAspect { Duration duration = Duration.ofSeconds(rateLimit.timeUnit().toSeconds(rateLimit.timeValue())); String customKey = rateLimit.key(); - String key = clientIp + "-" + (customKey.isEmpty() ? method.getName() : customKey); + Optional optionalUUID = authenticatedUuid.tryGetAuthenticatedUserUuid(); + if (optionalUUID.isEmpty()) { + return joinPoint.proceed(); + } + UUID uuid = optionalUUID.get(); + String key = uuid + "-" + (customKey.isEmpty() ? method.getName() : customKey); boolean allowed = rateLimiterService.tryAcquire(key, limit, duration); @@ -67,7 +75,7 @@ public class RateLimitAspect { return joinPoint.proceed(); } else { - log.warn("Rate limit exceeded for IP: {}, endpoint: {}", clientIp, request.getRequestURI()); + log.warn("Rate limit exceeded for uuid: {}, endpoint: {}", uuid, request.getRequestURI()); Duration nextResetTime = rateLimiterService.getNextResetTime(key, duration);