WebSocketProtocolHandler.java
package sprout.server.builtins;
import sprout.beans.annotation.Component;
import sprout.mvc.http.HttpRequest;
import sprout.mvc.http.parser.HttpRequestParser;
import sprout.server.AcceptableProtocolHandler;
import sprout.server.ProtocolHandler;
import sprout.server.ReadableProtocolHandler;
import sprout.server.argument.WebSocketArgumentResolver;
import sprout.server.websocket.*;
import sprout.server.websocket.endpoint.WebSocketEndpointInfo;
import sprout.server.websocket.endpoint.WebSocketEndpointRegistry;
import sprout.server.websocket.framehandler.FrameHandler;
import sprout.server.websocket.handler.WebSocketHandshakeHandler;
import sprout.server.websocket.message.WebSocketMessageDispatcher;
import sprout.server.websocket.message.WebSocketMessageParser;
import java.io.*;
import java.lang.reflect.Method;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/**
 * {@link AcceptableProtocolHandler} for the {@code "WEBSOCKET"} protocol.
 *
 * <p>{@link #accept} reads the HTTP upgrade request from a newly accepted channel, resolves the
 * target endpoint, performs the opening handshake and, on success, registers a
 * {@link DefaultWebSocketSession} with the {@link Selector} for {@link SelectionKey#OP_READ}.
 */
@Component
public class WebSocketProtocolHandler implements AcceptableProtocolHandler {
    /**
     * Largest request body buffered while reading the handshake request. A WebSocket upgrade
     * request is a GET (RFC 6455) and normally has no body; anything declared larger than this is
     * rejected instead of allocating a buffer sized by an untrusted header.
     */
    private static final int MAX_HANDSHAKE_BODY_LENGTH = 64 * 1024;

    /** Header prefix matched case-insensitively (locale-independent) to find Content-Length. */
    private static final String CONTENT_LENGTH_PREFIX = "content-length:";

    private final WebSocketHandshakeHandler handshakeHandler;
    private final WebSocketContainer webSocketContainer;
    private final WebSocketEndpointRegistry endpointRegistry;
    private final HttpRequestParser httpRequestParser;
    private final WebSocketFrameParser frameParser;
    private final WebSocketFrameEncoder frameEncoder;
    private final List<WebSocketArgumentResolver> webSocketArgumentResolvers;
    private final List<WebSocketMessageDispatcher> messageDispatchers;
    private final CloseListener closeListener;
    private final List<FrameHandler> frameHandlers;

    public WebSocketProtocolHandler(
            WebSocketHandshakeHandler handshakeHandler,
            WebSocketContainer webSocketContainer,
            WebSocketEndpointRegistry endpointRegistry,
            HttpRequestParser httpRequestParser,
            WebSocketFrameParser frameParser,
            WebSocketFrameEncoder frameEncoder,
            List<WebSocketArgumentResolver> webSocketArgumentResolvers,
            List<WebSocketMessageDispatcher> messageDispatchers,
            CloseListener closeListener,
            List<FrameHandler> frameHandlers
    ) {
        this.handshakeHandler = handshakeHandler;
        this.webSocketContainer = webSocketContainer;
        this.endpointRegistry = endpointRegistry;
        this.httpRequestParser = httpRequestParser;
        this.frameParser = frameParser;
        this.frameEncoder = frameEncoder;
        this.webSocketArgumentResolvers = webSocketArgumentResolvers;
        this.messageDispatchers = messageDispatchers;
        this.closeListener = closeListener;
        this.frameHandlers = frameHandlers;
    }

    /** Returns {@code true} only for the protocol name {@code "WEBSOCKET"}. */
    @Override
    public boolean supports(String protocol) {
        return "WEBSOCKET".equals(protocol);
    }

    /**
     * Completes the WebSocket opening handshake on a newly accepted connection.
     *
     * <p>The HTTP request is read through the socket's blocking streams. The channel is closed if
     * the request is invalid, if no endpoint matches its path (after a 404 response), or if the
     * handshake handler rejects it. On success a session is created and added to the container
     * under the endpoint's original path pattern. The channel is then switched to non-blocking
     * mode and registered for {@code OP_READ} with the session attached, and the session's
     * on-open callback is invoked.
     *
     * <p>If reading, parsing, endpoint lookup or the handshake throws, the channel is closed before
     * the exception is rethrown so the connection is not leaked.
     *
     * @param channel the accepted client channel; the stream-based handshake read only works while
     *     the channel is in blocking mode (see {@link Socket#getInputStream()})
     * @param selector selector the session is registered with after a successful handshake
     * @param byteBuffer not used by this implementation. NOTE(review): if the caller already read
     *     part of the handshake request into this buffer during protocol detection, those bytes
     *     are not seen here — confirm against the caller.
     * @throws Exception any failure from I/O, request parsing, the handshake, or session setup
     */
    @Override
    public void accept(SocketChannel channel, Selector selector, ByteBuffer byteBuffer) throws Exception {
        HttpRequest<?> request;
        WebSocketEndpointInfo endpointInfo;
        try {
            Socket socket = channel.socket();
            // ISO-8859-1 maps every byte to exactly one char. HTTP header bytes therefore decode
            // losslessly, and the Content-Length (a byte count) matches the number of chars read.
            // The streams are deliberately not closed: closing a socket stream closes the socket.
            BufferedReader httpReader = new BufferedReader(
                    new InputStreamReader(socket.getInputStream(), StandardCharsets.ISO_8859_1));
            // UTF-8 so the bytes written agree with the UTF-8 Content-Length in sendHttpResponse.
            BufferedWriter httpWriter = new BufferedWriter(
                    new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8));

            // 1. Parse the initial HTTP request (the WebSocket handshake request).
            String rawHttpRequest = readRawHttpRequestContent(httpReader);
            request = httpRequestParser.parse(rawHttpRequest);
            if (!request.isValid()) {
                System.out.println("Empty or invalid HTTP request for websocket handshake. Closing socket.");
                channel.close();
                return;
            }

            // 2. Resolve the WebSocket endpoint for the request path.
            String requestPath = request.getPath();
            endpointInfo = endpointRegistry.getEndpointInfo(requestPath);
            if (endpointInfo == null) {
                sendHttpResponse(httpWriter, 404, "Not Found", "No WebSocket endpoint found for " + requestPath);
                channel.close();
                return;
            }

            // 3. Perform the handshake (writes the handshake response through httpWriter).
            boolean handshakeSuccess = handshakeHandler.performHandshake(request, httpWriter);
            if (!handshakeSuccess) {
                System.out.println("WebSocket handshake failed. Closing socket.");
                channel.close();
                return;
            }
        } catch (Exception e) {
            closeQuietly(channel, e);
            throw e;
        }

        // 4. Handshake succeeded: create the session and register it with the container and selector.
        String sessionId = UUID.randomUUID().toString();
        Map<String, String> pathVars = endpointInfo.getPathPattern().extractPathVariables(request.getPath());
        WebSocketSession wsSession = new DefaultWebSocketSession(sessionId, channel, selector, request, endpointInfo,
                frameParser, frameEncoder, pathVars, webSocketArgumentResolvers, messageDispatchers,
                closeListener, frameHandlers);
        webSocketContainer.addSession(endpointInfo.getPathPattern().getOriginalPattern(), wsSession);
        // Selector registration requires non-blocking mode, while the stream-based handshake above
        // requires blocking mode. This is a no-op if the channel is already non-blocking.
        channel.configureBlocking(false);
        // Attach in the same call so the key never exists without its session.
        channel.register(selector, SelectionKey.OP_READ, wsSession);
        wsSession.callOnOpenMethod();
    }

    /**
     * Closes {@code channel}, recording any failure to close as a suppressed exception on
     * {@code cause} so the original error is the one propagated.
     */
    private static void closeQuietly(SocketChannel channel, Exception cause) {
        try {
            channel.close();
        } catch (IOException closeFailure) {
            cause.addSuppressed(closeFailure);
        }
    }

    /**
     * Reads the request line, headers and (optional) body of an HTTP request.
     *
     * <p>The returned text contains the request line and each header line terminated by CRLF,
     * followed by a CRLF separator. If a blank line ended the headers and a positive Content-Length
     * was seen, it is followed by up to that many body characters (fewer if the stream ends first).
     * If the stream ends before the blank line, no body is read. An unparsable Content-Length
     * value is logged and treated as 0.
     *
     * @param in reader positioned at the start of the request
     * @return the reconstructed raw request text
     * @throws IOException if reading fails, or if the declared Content-Length exceeds
     *     {@link #MAX_HANDSHAKE_BODY_LENGTH}
     */
    private String readRawHttpRequestContent(BufferedReader in) throws IOException {
        StringBuilder sb = new StringBuilder();
        String line;
        int contentLength = 0;
        // Set once the blank line terminating the header section has been seen.
        boolean headersDone = false;

        // Read the request line and headers.
        // readLine() returns null when the client has closed the connection.
        // An empty line marks the end of the headers.
        while ((line = in.readLine()) != null) {
            if (line.isEmpty()) { // end of headers (CRLF CRLF, or LF LF)
                headersDone = true;
                break;
            }
            sb.append(line).append("\r\n"); // re-add the CRLF that readLine() stripped
            // Case-insensitive and locale-independent, unlike toLowerCase() with the default locale.
            if (line.regionMatches(true, 0, CONTENT_LENGTH_PREFIX, 0, CONTENT_LENGTH_PREFIX.length())) {
                try {
                    contentLength = Integer.parseInt(line.substring(CONTENT_LENGTH_PREFIX.length()).trim());
                } catch (NumberFormatException e) {
                    System.err.println("Warning: Invalid Content-Length header: " + line);
                    contentLength = 0; // fall back to "no body" on parse failure
                }
            }
        }
        sb.append("\r\n"); // separator between headers and body

        // Read the body only when the header section completed and a positive length was declared.
        if (contentLength > 0 && headersDone) {
            if (contentLength > MAX_HANDSHAKE_BODY_LENGTH) {
                throw new IOException("Handshake request body too large: Content-Length " + contentLength
                        + " exceeds " + MAX_HANDSHAKE_BODY_LENGTH);
            }
            char[] body = new char[contentLength];
            int totalRead = 0;
            int read;
            // read() may return fewer chars than requested, so loop until done or end of stream.
            while (totalRead < contentLength && (read = in.read(body, totalRead, contentLength - totalRead)) != -1) {
                totalRead += read;
            }
            sb.append(body, 0, totalRead); // append only what was actually read
        }
        return sb.toString();
    }

    /**
     * Writes a complete plain-text HTTP/1.1 response and flushes it. It is used when no endpoint
     * matches the request. {@code Connection: close} is sent because callers close the connection
     * right afterwards.
     *
     * @param out writer over the socket's output stream; the Content-Length is the UTF-8 byte length
     *     of {@code message}, so {@code out} is expected to encode as UTF-8
     * @param statusCode HTTP status code
     * @param statusText reason phrase
     * @param message response body
     * @throws IOException if writing fails
     */
    private void sendHttpResponse(BufferedWriter out, int statusCode, String statusText, String message) throws IOException {
        out.write("HTTP/1.1 " + statusCode + " " + statusText + "\r\n");
        out.write("Content-Type: text/plain;charset=UTF-8\r\n");
        out.write("Content-Length: " + message.getBytes(StandardCharsets.UTF_8).length + "\r\n");
        out.write("Connection: close\r\n");
        out.write("\r\n");
        out.write(message);
        out.flush();
    }
}