Refactor RateLimitAspect to use authenticated UUID instead of client IP for rate limiting. Enhance AuthenticatedUuid with optional UUID retrieval method.

This commit is contained in:
akastijn 2025-10-24 22:27:04 +02:00
parent e766fd1125
commit 8b265514a6
2 changed files with 36 additions and 3 deletions

View File

@ -10,6 +10,7 @@ import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ResponseStatusException;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@Slf4j @Slf4j
@ -45,4 +46,28 @@ public class AuthenticatedUuid {
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid UUID format"); throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid UUID format");
} }
} }
/**
* Extracts the authenticated user's UUID from the JWT token.
*
* @return The UUID of the authenticated user
*/
public Optional<UUID> tryGetAuthenticatedUserUuid() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null || !(authentication.getPrincipal() instanceof Jwt jwt)) {
if (unsecured) {
return Optional.of(UUID.fromString("55e46bc3-2a29-4c53-850f-dbd944dc5c5f"));
}
return Optional.empty();
}
String stringUuid = jwt.getSubject();
try {
return Optional.of(UUID.fromString(stringUuid));
} catch (IllegalArgumentException e) {
return Optional.empty();
}
}
} }

View File

@ -1,5 +1,6 @@
package com.alttd.altitudeweb.services.limits; package com.alttd.altitudeweb.services.limits;
import com.alttd.altitudeweb.controllers.data_from_auth.AuthenticatedUuid;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@ -16,6 +17,8 @@ import org.springframework.web.context.request.ServletRequestAttributes;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.time.Duration; import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
@Aspect @Aspect
@Component @Component
@ -24,6 +27,7 @@ import java.time.Duration;
public class RateLimitAspect { public class RateLimitAspect {
private final InMemoryRateLimiterService rateLimiterService; private final InMemoryRateLimiterService rateLimiterService;
private final AuthenticatedUuid authenticatedUuid;
@Around(""" @Around("""
@annotation(com.alttd.altitudeweb.services.limits.RateLimit) @annotation(com.alttd.altitudeweb.services.limits.RateLimit)
@ -37,7 +41,6 @@ public class RateLimitAspect {
HttpServletRequest request = requestAttributes.getRequest(); HttpServletRequest request = requestAttributes.getRequest();
HttpServletResponse response = requestAttributes.getResponse(); HttpServletResponse response = requestAttributes.getResponse();
String clientIp = request.getRemoteAddr();
MethodSignature signature = (MethodSignature) joinPoint.getSignature(); MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod(); Method method = signature.getMethod();
@ -54,7 +57,12 @@ public class RateLimitAspect {
Duration duration = Duration.ofSeconds(rateLimit.timeUnit().toSeconds(rateLimit.timeValue())); Duration duration = Duration.ofSeconds(rateLimit.timeUnit().toSeconds(rateLimit.timeValue()));
String customKey = rateLimit.key(); String customKey = rateLimit.key();
String key = clientIp + "-" + (customKey.isEmpty() ? method.getName() : customKey); Optional<UUID> optionalUUID = authenticatedUuid.tryGetAuthenticatedUserUuid();
if (optionalUUID.isEmpty()) {
return joinPoint.proceed();
}
UUID uuid = optionalUUID.get();
String key = uuid + "-" + (customKey.isEmpty() ? method.getName() : customKey);
boolean allowed = rateLimiterService.tryAcquire(key, limit, duration); boolean allowed = rateLimiterService.tryAcquire(key, limit, duration);
@ -67,7 +75,7 @@ public class RateLimitAspect {
return joinPoint.proceed(); return joinPoint.proceed();
} else { } else {
log.warn("Rate limit exceeded for IP: {}, endpoint: {}", clientIp, request.getRequestURI()); log.warn("Rate limit exceeded for uuid: {}, endpoint: {}", uuid, request.getRequestURI());
Duration nextResetTime = rateLimiterService.getNextResetTime(key, duration); Duration nextResetTime = rateLimiterService.getNextResetTime(key, duration);