WebSocketProtocolHandler.java

package sprout.server.builtins;

import sprout.beans.annotation.Component;
import sprout.mvc.http.HttpRequest;
import sprout.mvc.http.parser.HttpRequestParser;
import sprout.server.AcceptableProtocolHandler;
import sprout.server.ProtocolHandler;
import sprout.server.ReadableProtocolHandler;
import sprout.server.argument.WebSocketArgumentResolver;
import sprout.server.websocket.*;
import sprout.server.websocket.endpoint.WebSocketEndpointInfo;
import sprout.server.websocket.endpoint.WebSocketEndpointRegistry;
import sprout.server.websocket.framehandler.FrameHandler;
import sprout.server.websocket.handler.WebSocketHandshakeHandler;
import sprout.server.websocket.message.WebSocketMessageDispatcher;
import sprout.server.websocket.message.WebSocketMessageParser;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.UUID;

/**
 * {@link AcceptableProtocolHandler} that upgrades a freshly accepted connection to a WebSocket.
 *
 * <p>For each accepted channel it reads the HTTP handshake request, looks up the endpoint
 * registered for the request path, performs the handshake, creates a
 * {@link DefaultWebSocketSession}, adds it to the {@link WebSocketContainer}, registers the
 * channel for {@link SelectionKey#OP_READ} with the session attached, and finally invokes the
 * session's on-open callback.
 */
@Component
public class WebSocketProtocolHandler implements AcceptableProtocolHandler {

    /** Maximum number of handshake-request bytes accepted from the channel. */
    private static final int MAX_HANDSHAKE_REQUEST_BYTES = 10240;

    /** Size of the scratch buffer used for each additional blocking read. */
    private static final int READ_CHUNK_BYTES = 8192;

    /** End of the HTTP header section ({@code CRLF CRLF}) as raw ASCII bytes. */
    private static final byte[] HEADER_TERMINATOR = {'\r', '\n', '\r', '\n'};

    private final WebSocketHandshakeHandler handshakeHandler;
    private final WebSocketContainer webSocketContainer;
    private final WebSocketEndpointRegistry endpointRegistry;
    private final HttpRequestParser httpRequestParser;
    private final WebSocketFrameParser frameParser;
    private final WebSocketFrameEncoder frameEncoder;
    private final List<WebSocketArgumentResolver> webSocketArgumentResolvers;
    private final List<WebSocketMessageDispatcher> messageDispatchers;
    private final CloseListener closeListener;
    private final List<FrameHandler> frameHandlers;

    public WebSocketProtocolHandler(
            WebSocketHandshakeHandler handshakeHandler,
            WebSocketContainer webSocketContainer,
            WebSocketEndpointRegistry endpointRegistry,
            HttpRequestParser httpRequestParser,
            WebSocketFrameParser frameParser,
            WebSocketFrameEncoder frameEncoder,
            List<WebSocketArgumentResolver> webSocketArgumentResolvers,
            List<WebSocketMessageDispatcher> messageDispatchers,
            CloseListener closeListener,
            List<FrameHandler> frameHandlers
    ) {
        this.handshakeHandler = handshakeHandler;
        this.webSocketContainer = webSocketContainer;
        this.endpointRegistry = endpointRegistry;
        this.httpRequestParser = httpRequestParser;
        this.frameParser = frameParser;
        this.frameEncoder = frameEncoder;
        this.webSocketArgumentResolvers = webSocketArgumentResolvers;
        this.messageDispatchers = messageDispatchers;
        this.closeListener = closeListener;
        this.frameHandlers = frameHandlers;
    }

    /** Returns {@code true} only for the {@code "WEBSOCKET"} protocol identifier. */
    @Override
    public boolean supports(String protocol) {
        return "WEBSOCKET".equals(protocol);
    }

    /**
     * Performs the WebSocket handshake on {@code channel} and, on success, registers a new session.
     *
     * <p>The channel is closed when the request is empty or invalid, when no endpoint matches the
     * request path (after replying {@code 404 Not Found}), when the handshake fails, or when
     * reading/parsing the request throws (the exception is then rethrown).
     *
     * @param channel    the accepted client channel
     * @param selector   the selector the channel is registered with for {@code OP_READ} after a
     *                   successful handshake
     * @param byteBuffer bytes already read from the channel, or {@code null}; its remaining
     *                   content is consumed as the beginning of the handshake request
     * @throws Exception if reading, parsing, the handshake, or session setup fails
     */
    @Override
    public void accept(SocketChannel channel, Selector selector, ByteBuffer byteBuffer) throws Exception {
        // 1. Read and parse the initial HTTP request (the WebSocket handshake request).
        HttpRequest<?> request;
        try {
            String rawHttpRequest = readRawHttpRequestContent(channel, byteBuffer);
            // An empty string means the peer closed the connection; there is nothing to parse.
            request = rawHttpRequest.isEmpty() ? null : httpRequestParser.parse(rawHttpRequest);
        } catch (Exception e) {
            // Do not leak the socket when the request is oversized, unreadable, or unparsable.
            closeSuppressingFailure(channel, e);
            throw e;
        }
        if (request == null || !request.isValid()) {
            System.out.println("Empty or invalid HTTP request for websocket handshake. Closing socket.");
            channel.close();
            return;
        }

        // 2. Find the WebSocket endpoint for the request path.
        String requestPath = request.getPath();
        System.out.println("WebSocket handshake request received for path: " + requestPath + ". Trying to find matching endpoint.");
        WebSocketEndpointInfo endpointInfo = endpointRegistry.getEndpointInfo(requestPath);

        if (endpointInfo == null) {
            try {
                sendHttpResponse(channel, 404, "Not Found", "No WebSocket endpoint found for " + requestPath);
            } finally {
                // Close even if writing the 404 response fails.
                channel.close();
            }
            return;
        }

        // 3. Perform the handshake.
        boolean handshakeSuccess = handshakeHandler.performHandshake(request, channel);
        if (!handshakeSuccess) {
            System.out.println("WebSocket handshake failed. Closing socket.");
            channel.close();
            return;
        }

        // 4. Handshake succeeded: create and register the WebSocketSession.
        String sessionId = UUID.randomUUID().toString();
        Map<String, String> pathVars = endpointInfo.getPathPattern().extractPathVariables(request.getPath());

        // The session receives the argument resolvers, message dispatchers and frame handlers.
        WebSocketSession wsSession = new DefaultWebSocketSession(sessionId, channel, selector, request, endpointInfo, frameParser, frameEncoder, pathVars, webSocketArgumentResolvers, messageDispatchers, closeListener, frameHandlers);
        webSocketContainer.addSession(endpointInfo.getPathPattern().getOriginalPattern(), wsSession);

        // Subsequent reads are dispatched by the selector to the attached session.
        SelectionKey key = channel.register(selector, SelectionKey.OP_READ);
        key.attach(wsSession);
        wsSession.callOnOpenMethod();
    }

    /**
     * Closes {@code channel}, attaching any close failure to {@code failure} as a suppressed
     * exception so the original error is not masked.
     */
    private static void closeSuppressingFailure(SocketChannel channel, Exception failure) {
        try {
            channel.close();
        } catch (IOException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Reads the raw HTTP handshake request, up to and including the end of its header section.
     *
     * <p>Bytes remaining in {@code buffer} are consumed first. If they do not yet contain the
     * header terminator ({@code CRLF CRLF}), the channel is temporarily switched to blocking mode
     * and read in chunks until the terminator arrives, the peer closes the connection, or the
     * request exceeds {@value #MAX_HANDSHAKE_REQUEST_BYTES} bytes. If the channel was
     * non-blocking on entry, non-blocking mode is restored before returning.
     *
     * <p>Bytes are accumulated and decoded as UTF-8 only once at the end, so a multi-byte
     * character split across two reads is not corrupted. Any bytes that arrive after the
     * terminator in the same read are included in the returned text.
     *
     * <p>NOTE(review): {@code configureBlocking(true)} throws {@code IllegalBlockingModeException}
     * if the channel is currently registered with a selector, and the blocking read has no
     * timeout, so a slow client can stall the calling thread. Confirm both are acceptable for the
     * caller.
     *
     * @param channel the client channel to read from
     * @param buffer  bytes already read by the caller, or {@code null}
     * @return the decoded request text, or {@code ""} if the peer closed the connection first
     * @throws IOException if reading fails or the request exceeds the size limit
     */
    private String readRawHttpRequestContent(SocketChannel channel, ByteBuffer buffer) throws IOException {
        ByteArrayOutputStream received = new ByteArrayOutputStream();

        // 1) Start with the bytes the caller has already read from the channel.
        if (buffer != null && buffer.hasRemaining()) {
            byte[] alreadyRead = new byte[buffer.remaining()];
            buffer.get(alreadyRead);
            received.write(alreadyRead, 0, alreadyRead.length);
        }

        // 2) The header section may already be complete.
        if (indexOfHeaderTerminator(received.toByteArray(), 0) >= 0) {
            return new String(received.toByteArray(), StandardCharsets.UTF_8);
        }

        // 3) Incomplete: keep reading, in blocking mode for the duration of this method.
        boolean wasBlocking = channel.isBlocking();
        try {
            channel.configureBlocking(true);
            ByteBuffer readBuffer = ByteBuffer.allocate(READ_CHUNK_BYTES);

            while (true) {
                // The terminator may straddle two reads, so rescan from just before the new bytes.
                int searchFrom = received.size() - (HEADER_TERMINATOR.length - 1);

                readBuffer.clear();
                int bytesRead = channel.read(readBuffer);
                if (bytesRead == -1) {
                    return ""; // Connection closed by the peer.
                }
                if (bytesRead == 0) {
                    break; // Not expected in blocking mode; stop rather than spin.
                }
                received.write(readBuffer.array(), readBuffer.arrayOffset(), bytesRead);

                // Reject overly large requests.
                if (received.size() > MAX_HANDSHAKE_REQUEST_BYTES) {
                    throw new IOException("HTTP request too large");
                }
                if (indexOfHeaderTerminator(received.toByteArray(), searchFrom) >= 0) {
                    break;
                }
            }

            return new String(received.toByteArray(), StandardCharsets.UTF_8);
        } finally {
            // Restore the original non-blocking mode.
            if (!wasBlocking) {
                channel.configureBlocking(false);
            }
        }
    }

    /**
     * Returns the index of the first {@code CRLF CRLF} in {@code data} at or after {@code from}
     * (negative values are treated as 0), or {@code -1} if there is none.
     *
     * <p>Scanning raw bytes is safe for UTF-8 input because CR and LF never occur inside a
     * multi-byte UTF-8 sequence.
     */
    private static int indexOfHeaderTerminator(byte[] data, int from) {
        int last = data.length - HEADER_TERMINATOR.length;
        for (int i = Math.max(0, from); i <= last; i++) {
            int matched = 0;
            while (matched < HEADER_TERMINATOR.length && data[i + matched] == HEADER_TERMINATOR[matched]) {
                matched++;
            }
            if (matched == HEADER_TERMINATOR.length) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Writes a plain-text HTTP/1.1 response to {@code channel}.
     *
     * <p>{@code Connection: close} is sent because the only caller ({@code accept}) closes the
     * channel right after the response. The write loop retries until the whole response has
     * been written; on a non-blocking channel it may spin while the socket send buffer is full.
     *
     * @param channel    the client channel
     * @param statusCode the HTTP status code
     * @param statusText the HTTP reason phrase
     * @param message    the UTF-8 response body
     * @throws IOException if writing fails
     */
    private void sendHttpResponse(SocketChannel channel, int statusCode, String statusText, String message) throws IOException {
        String response = "HTTP/1.1 " + statusCode + " " + statusText + "\r\n" +
                         "Content-Type: text/plain;charset=UTF-8\r\n" +
                         "Content-Length: " + message.getBytes(StandardCharsets.UTF_8).length + "\r\n" +
                         "Connection: close\r\n" +
                         "\r\n" +
                         message;

        ByteBuffer buffer = ByteBuffer.wrap(response.getBytes(StandardCharsets.UTF_8));
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }
}