Project/AutoSchedule

8.5일차 - Spring WebSocket 보안 3단계 추가

sowon02 2025. 11. 11. 16:51

어제는 실시간 일정 편집, 자동 스케줄 배포, 협업 알림을 위해 WebSocket을 구현했다.
오늘은 이 WebSocket 통신의 보안성과 안정성을 높이기 위해, 세 가지 단계의 보안 인터셉터를 추가했다.
핸드셰이크 → 인증 → 트래픽 제어까지, STOMP 전 구간을 검증하는 구조다.


1. StrictHandshakeInterceptor – Origin / TLS 검사

WebSocket 연결 직전에 실행되는 핸드셰이크 인터셉터로, 다음 두 가지를 검사한다.

  • 허용 Origin 확인: 외부 도메인의 비정상 요청 차단
  • TLS 강제 여부 검사: 배포 환경에서는 wss:// 만 허용

이 값들은 모두 application.properties 의 app.websocket.allowed-origins, app.websocket.require-tls 로 관리된다.

/**
 * WebSocket Handshake 단계에서 Origin과 TLS를 검증하는 인터셉터.
 * 허용되지 않은 도메인이나 비보안 연결을 조기에 차단한다.
 */
@Component
public class StrictHandshakeInterceptor implements HandshakeInterceptor {

    private final WebSocketSecurityProperties properties;

    public StrictHandshakeInterceptor(WebSocketSecurityProperties properties) {
        this.properties = properties;
    }

    @Override
    public boolean beforeHandshake(@NonNull ServerHttpRequest request,
                                   @NonNull ServerHttpResponse response,
                                   @NonNull WebSocketHandler wsHandler,
                                   @NonNull java.util.Map<String, Object> attributes) {
        // 1) Origin 헤더가 허용된 도메인인지 확인
        if (!isOriginAllowed(request.getHeaders())) {
            return false;
        }
        // 2) 배포 환경에서 TLS가 강제된다면 https/wss 연결인지 검사
        if (properties.isRequireSecureHandshake() && !isSecure(request)) {
            return false;
        }
        return true;
    }

    @Override
    public void afterHandshake(@NonNull ServerHttpRequest request,
                               @NonNull ServerHttpResponse response,
                               @NonNull WebSocketHandler wsHandler,
                               Exception exception) {
        // no-op
    }

    /**
     * Origin 헤더가 비어 있거나 허용 목록에 포함되어 있는지 점검한다.
     */
    private boolean isOriginAllowed(HttpHeaders headers) {
        List<String> origins = headers.get(HttpHeaders.ORIGIN);
        if (origins == null || origins.isEmpty()) {
            // Origin 헤더가 없는 경우(같은 오리진)는 허용
            return true;
        }
        return origins.stream()
                .anyMatch(origin -> properties.getAllowedOrigins().contains(origin));
    }

    /**
     * 요청이 https/wss인지 확인한다. 역프록시 환경을 고려해 X-Forwarded-Proto도 검사한다.
     */
    private boolean isSecure(ServerHttpRequest request) {
        if (request instanceof ServletServerHttpRequest servletRequest) {
            HttpServletRequest httpServletRequest = servletRequest.getServletRequest();
            if (httpServletRequest.isSecure()) {
                return true;
            }
            String forwardedProto = httpServletRequest.getHeader("X-Forwarded-Proto");
            if (forwardedProto != null) {
                return "https".equalsIgnoreCase(forwardedProto) || "wss".equalsIgnoreCase(forwardedProto);
            }
        }

        URI uri = request.getURI();
        return Objects.equals(uri.getScheme(), "https") || Objects.equals(uri.getScheme(), "wss");
    }
}

2. StompJwtChannelInterceptor – STOMP 프레임별 JWT 인증

핸드셰이크에서 연결이 허용되더라도, 매 메시지마다 JWT 유효성을 검증해야 한다.

이 인터셉터는 STOMP 프레임의 헤더(Authorization)에서 토큰을 파싱하고, 유효하지 않으면 메시지 자체를 drop시킨다.
이로써 “한 번만 인증하고 끝”이 아니라, 모든 송수신 프레임 단위로 인증이 적용된다.

/**
 * STOMP CONNECT/SEND/SUBSCRIBE 프레임마다 JWT를 검증하는 인터셉터.
 * WebSocket 연결 이후에도 각 메시지 단위로 인증 정보를 확인한다.
 */
@Component
public class StompJwtChannelInterceptor implements ChannelInterceptor {

    private static final String AUTHORIZATION_HEADER = "Authorization";
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtUtil jwtUtil;

    /**
     * SockJS CONNECT 프레임에서 토큰을 전파하지 못하는 경우를 대비하여
     * handshake 단계에서 저장한 토큰을 sessionId 기준으로 보관한다.
     */
    private final Map<String, String> sessionTokenCache = new ConcurrentHashMap<>();

    public StompJwtChannelInterceptor(JwtUtil jwtUtil) {
        this.jwtUtil = jwtUtil;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
        StompCommand command = accessor.getCommand();
        if (command == null) {
            // STOMP 프레임이 아니면 그대로 통과
            return message;
        }

        switch (command) {
            case CONNECT -> handleConnect(accessor);               // 최초 연결 시 토큰 검증
            case SEND, SUBSCRIBE -> ensureAuthenticated(accessor); // 프레임 전송/구독 시 인증 여부 확인
            case DISCONNECT -> cleanupSession(accessor);           // 세션 정리
            default -> {
            }
        }

        return message;
    }

    private void handleConnect(StompHeaderAccessor accessor) {
        String sessionId = accessor.getSessionId();
        String rawToken = resolveToken(accessor);

        if (!StringUtils.hasText(rawToken)) {
            // CONNECT 단계에서 토큰이 없으면 연결 자체를 거부
            throw new MessageDeliveryException("Missing Authorization header for STOMP CONNECT");
        }

        String token = normalizeBearerToken(rawToken);
        if (!jwtUtil.isTokenValid(token)) {
            throw new MessageDeliveryException("Invalid JWT token in STOMP CONNECT");
        }

        Long userId = jwtUtil.getUserIdFromToken(token);
        Principal principal = new UsernamePasswordAuthenticationToken(
                userId,
                null,
                Collections.emptyList()
        );

        accessor.setUser(principal);
        accessor.getSessionAttributes().put("userId", userId);
        accessor.getSessionAttributes().put("token", token);

        if (StringUtils.hasText(sessionId)) {
            sessionTokenCache.put(sessionId, token);
        }
    }

    private void ensureAuthenticated(StompHeaderAccessor accessor) {
        // CONNECT에서 인증된 Principal이 없으면 세션 캐시에서 토큰을 복구한다.
        Principal principal = accessor.getUser();
        if (principal == null) {
            String sessionId = accessor.getSessionId();
            String token = sessionId != null ? sessionTokenCache.get(sessionId) : null;
            if (!StringUtils.hasText(token) || !jwtUtil.isTokenValid(token)) {
                throw new MessageDeliveryException("Unauthenticated STOMP frame rejected");
            }
            Long userId = jwtUtil.getUserIdFromToken(token);
            principal = new UsernamePasswordAuthenticationToken(
                    userId,
                    null,
                    Collections.emptyList()
            );
            accessor.setUser(principal);
        }
    }

    private void cleanupSession(StompHeaderAccessor accessor) {
        String sessionId = accessor.getSessionId();
        if (sessionId != null) {
            sessionTokenCache.remove(sessionId);
        }
    }

    @Nullable
    private String resolveToken(StompHeaderAccessor accessor) {
        // STOMP 헤더 우선, 없으면 handshake 단계에서 저장한 세션 속성을 확인
        List<String> nativeHeaders = accessor.getNativeHeader(AUTHORIZATION_HEADER);
        if (!CollectionUtils.isEmpty(nativeHeaders)) {
            return nativeHeaders.get(0);
        }

        Object tokenAttr = accessor.getSessionAttributes() != null
                ? accessor.getSessionAttributes().get("token")
                : null;
        if (tokenAttr instanceof String tokenString) {
            return tokenString;
        }

        return null;
    }

    private String normalizeBearerToken(String value) {
        if (!StringUtils.hasText(value)) {
            return value;
        }
        if (value.startsWith(BEARER_PREFIX)) {
            return value.substring(BEARER_PREFIX.length()).trim();
        }
        return value;
    }
}

3. StompRateLimitingChannelInterceptor – 메시지 폭주 방지

일부 클라이언트가 짧은 시간에 다량의 메시지를 전송하는 걸 막기 위해, RateLimiter(예: Bucket4j) 기반의 인터셉터를 추가했다.
초당 전송 가능한 메시지 수를 제한하여 서버 부하 및 악의적 요청을 방어한다.

/**
 * STOMP SEND 프레임에 대한 전송 빈도 및 페이로드 크기를 제어하는 인터셉터.
 * 사용자 또는 세션이 과도한 트래픽을 발생시키지 못하도록 보호한다.
 */
@Component
public class StompRateLimitingChannelInterceptor implements ChannelInterceptor {

    private final WebSocketSecurityProperties properties;
    private final Map<String, RateWindow> windows = new ConcurrentHashMap<>();

    public StompRateLimitingChannelInterceptor(WebSocketSecurityProperties properties) {
        this.properties = properties;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
        StompCommand command = accessor.getCommand();
        if (command == null) {
            // STOMP 메시지가 아니면 그대로 통과
            return message;
        }

        switch (command) {
            case SEND -> {
                enforceRateLimit(accessor);  // 초과 전송 시 예외
                enforcePayloadLimit(message); // 메시지 크기 제한
            }
            case DISCONNECT -> cleanup(accessor);
            default -> {
            }
        }

        return message;
    }

    private void enforceRateLimit(StompHeaderAccessor accessor) {
        String key = resolveLimiterKey(accessor);
        if (key == null) {
            throw new MessageDeliveryException("Unauthenticated sessions cannot send STOMP messages");
        }

        long now = Instant.now().toEpochMilli();
        long windowMillis = properties.getRateLimit().getWindowSeconds() * 1000L;

        RateWindow window = windows.compute(key, (k, existing) -> {
            if (existing == null || now - existing.windowStartMs >= windowMillis) {
                // 새 윈도우 시작
                return new RateWindow(now, 1);
            }
            int nextCount = existing.count + 1;
            if (nextCount > properties.getRateLimit().getMaxMessages()) {
                existing.count = nextCount;
                return existing;
            }
            existing.count = nextCount;
            return existing;
        });

        if (window != null
                && now - window.windowStartMs < windowMillis
                && window.count > properties.getRateLimit().getMaxMessages()) {
            // 설정된 제한을 초과하면 메시지를 거부
            throw new MessageDeliveryException("Rate limit exceeded for STOMP session");
        }
    }

    private void enforcePayloadLimit(Message<?> message) {
        Object payload = message.getPayload();
        int limit = properties.getMessageSizeLimitBytes();

        if (payload instanceof byte[] bytes && bytes.length > limit) {
            throw new MessageDeliveryException("STOMP payload exceeds allowed size");
        }
        if (payload instanceof String text && text.getBytes(StandardCharsets.UTF_8).length > limit) {
            throw new MessageDeliveryException("STOMP payload exceeds allowed size");
        }
    }

    private void cleanup(StompHeaderAccessor accessor) {
        // 세션 종료 시 레이트 윈도우 캐시 정리
        String key = resolveLimiterKey(accessor);
        if (key != null) {
            windows.remove(key);
        }
    }

    private String resolveLimiterKey(StompHeaderAccessor accessor) {
        Principal user = accessor.getUser();
        if (user != null) {
            return user.getName();
        }
        return accessor.getSessionId();
    }

    private static final class RateWindow {
        private final long windowStartMs;
        private int count;

        private RateWindow(long windowStartMs, int count) {
            this.windowStartMs = windowStartMs;
            this.count = count;
        }
    }
}

4. WebSocketConfig – 인터셉터 통합

WebSocketConfig에서 세 인터셉터를 등록하고, STOMP 브로커 설정과 함께 관리했다.
또한 허용 Origin, 메시지 크기 제한 등은 모두 외부 설정(application.properties)으로 분리했다.

/**
 * STOMP 엔드포인트와 브로커 설정을 담당하고,
 * Handshake/JWT/RateLimit 인터셉터를 전체 WebSocket 파이프라인에 적용한다.
 */
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    private final StompJwtChannelInterceptor jwtChannelInterceptor;
    private final StompRateLimitingChannelInterceptor rateLimitingChannelInterceptor;
    private final StrictHandshakeInterceptor strictHandshakeInterceptor;
    private final WebSocketSecurityProperties webSocketSecurityProperties;

    public WebSocketConfig(StompJwtChannelInterceptor jwtChannelInterceptor,
                           StompRateLimitingChannelInterceptor rateLimitingChannelInterceptor,
                           StrictHandshakeInterceptor strictHandshakeInterceptor,
                           WebSocketSecurityProperties webSocketSecurityProperties) {
        this.jwtChannelInterceptor = jwtChannelInterceptor;
        this.rateLimitingChannelInterceptor = rateLimitingChannelInterceptor;
        this.strictHandshakeInterceptor = strictHandshakeInterceptor;
        this.webSocketSecurityProperties = webSocketSecurityProperties;
    }

    @Override
    public void configureMessageBroker(@NonNull MessageBrokerRegistry registry) {
        registry.enableSimpleBroker("/topic", "/queue");
        registry.setApplicationDestinationPrefixes("/app");
    }

    @Override
    public void registerStompEndpoints(@NonNull StompEndpointRegistry registry) {
        String[] allowedOrigins = webSocketSecurityProperties.getAllowedOrigins()
                .toArray(new String[0]);

        registry.addEndpoint("/ws")
                // Handshake 인터셉터로 Origin/TLS 검증
                .addInterceptors(strictHandshakeInterceptor)
                // application.properties에서 관리하는 Origin 목록 적용
                .setAllowedOrigins(allowedOrigins)
                .withSockJS();
    }

    @Override
    public void configureClientInboundChannel(@NonNull ChannelRegistration registration) {
        // 클라이언트에서 들어오는 STOMP 프레임에 JWT와 레이트 리밋 순으로 적용
        registration.interceptors(rateLimitingChannelInterceptor, jwtChannelInterceptor);
    }

    @Override
    public void configureWebSocketTransport(@NonNull WebSocketTransportRegistration registration) {
        // 메시지 크기 제한도 외부 설정으로 제어한다.
        registration.setMessageSizeLimit(webSocketSecurityProperties.getMessageSizeLimitBytes());
    }
}

5. application.properties

app.websocket.allowed-origins=http://localhost:5173,http://localhost:3000
app.websocket.require-secure-handshake=false
app.websocket.message-size-limit-bytes=65536
app.websocket.rate-limit.window-seconds=10
app.websocket.rate-limit.max-messages=60

6. 프론트엔드 ws.ts 수정

프론트엔드에서도 보안을 강화했다.
브라우저 프로토콜을 감지해 자동으로 wss:// 접속하도록 변경:

const protocol = window.location.protocol === "https:" ? "wss" : "ws";
const socket = new WebSocket(`${protocol}://${location.host}/ws`);