DefaultWebSocketHandshakeHandler.java
package sprout.server.websocket.handler;
import sprout.beans.annotation.Component;
import sprout.mvc.http.HttpRequest;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Map;
/**
 * Default server-side implementation of the WebSocket opening handshake (RFC 6455, section 4.2).
 *
 * <p>Validates the client's upgrade request, computes the {@code Sec-WebSocket-Accept} value and
 * writes either a {@code 101 Switching Protocols} response or a plain-text HTTP error response
 * directly to the supplied {@link SocketChannel}.
 */
@Component
public class DefaultWebSocketHandshakeHandler implements WebSocketHandshakeHandler {

    /** Fixed GUID appended to the client key before hashing (RFC 6455, section 1.3). */
    private static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** The only WebSocket protocol version this handler speaks (RFC 6455). */
    private static final String SUPPORTED_VERSION = "13";

    /** A valid Sec-WebSocket-Key is the base64 encoding of a 16-byte nonce (RFC 6455, section 4.2.1). */
    private static final int KEY_NONCE_LENGTH = 16;

    /**
     * Validates the upgrade request and, if it is acceptable, completes the handshake.
     *
     * @param request the parsed HTTP upgrade request
     * @param channel the client connection; the handshake (or error) response is written to it
     * @return {@code true} if a {@code 101 Switching Protocols} response was sent, {@code false}
     *     if an error response was sent instead
     * @throws IOException if writing the response to the channel fails
     */
    @Override
    public boolean performHandshake(HttpRequest<?> request, SocketChannel channel) throws IOException {
        // 1. Validate the required headers. Header field names are case-insensitive, so do not
        //    depend on the exact casing the client (or an intermediary) used.
        Map<String, String> headers = request.getHeaders();
        String upgradeHeader = getHeaderIgnoreCase(headers, "Upgrade");
        String connectionHeader = getHeaderIgnoreCase(headers, "Connection");
        String secWebSocketKey = getHeaderIgnoreCase(headers, "Sec-WebSocket-Key");
        String secWebSocketVersion = getHeaderIgnoreCase(headers, "Sec-WebSocket-Version");

        // Upgrade and Connection may carry comma-separated token lists (e.g. "keep-alive, Upgrade"),
        // so match whole tokens case-insensitively rather than substrings.
        if (!containsToken(upgradeHeader, "websocket")
                || !containsToken(connectionHeader, "upgrade")
                || !isValidKey(secWebSocketKey)
                || secWebSocketVersion == null || secWebSocketVersion.isBlank()) {
            sendHandshakeErrorResponse(channel, 400, "Bad Request", "Invalid WebSocket handshake request headers.");
            return false;
        }

        // An unsupported version must be answered with an error such as 426 plus the version(s)
        // we do support (RFC 6455, section 4.2.2), so the client can retry compatibly.
        if (!SUPPORTED_VERSION.equals(secWebSocketVersion.strip())) {
            sendHandshakeErrorResponse(channel, 426, "Upgrade Required",
                    "Unsupported WebSocket version.",
                    "Sec-WebSocket-Version: " + SUPPORTED_VERSION + "\r\n");
            return false;
        }

        // 2. Compute the Sec-WebSocket-Accept value; surrounding whitespace is not part of the key.
        String secWebSocketAccept;
        try {
            secWebSocketAccept = generateSecWebSocketAccept(secWebSocketKey.strip());
        } catch (NoSuchAlgorithmException e) {
            // Every Java SE implementation must provide SHA-1, so this is a defensive fallback only.
            System.err.println("SHA-1 algorithm not found for WebSocket handshake: " + e.getMessage());
            sendHandshakeErrorResponse(channel, 500, "Internal Server Error", "Server error during handshake.");
            return false;
        }

        // 3. Send the successful handshake response.
        String response = "HTTP/1.1 101 Switching Protocols\r\n" +
                "Upgrade: websocket\r\n" +
                "Connection: Upgrade\r\n" +
                "Sec-WebSocket-Accept: " + secWebSocketAccept + "\r\n" +
                "\r\n";
        writeFully(channel, response);
        System.out.println("WebSocket handshake successful for path: " + request.getPath());
        return true;
    }

    /**
     * Computes the Sec-WebSocket-Accept value: base64(SHA-1(key + GUID)).
     *
     * @param secWebSocketKey the client's key exactly as sent (not base64-decoded), without
     *     surrounding whitespace
     * @return the base64-encoded SHA-1 digest
     * @throws NoSuchAlgorithmException if no SHA-1 provider is available
     */
    private String generateSecWebSocketAccept(String secWebSocketKey) throws NoSuchAlgorithmException {
        String combined = secWebSocketKey + WEBSOCKET_GUID;
        MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
        // The key and GUID are ASCII-only, so encode them as ASCII before hashing.
        byte[] sha1Hash = sha1.digest(combined.getBytes(StandardCharsets.US_ASCII));
        return Base64.getEncoder().encodeToString(sha1Hash);
    }

    /** Sends a plain-text HTTP error response for a failed handshake. */
    private void sendHandshakeErrorResponse(SocketChannel channel, int statusCode, String statusText, String message) throws IOException {
        sendHandshakeErrorResponse(channel, statusCode, statusText, message, "");
    }

    /**
     * Sends a plain-text HTTP error response for a failed handshake.
     *
     * @param extraHeaders additional header lines, each terminated by CRLF (may be empty)
     */
    private void sendHandshakeErrorResponse(SocketChannel channel, int statusCode, String statusText,
                                            String message, String extraHeaders) throws IOException {
        // Content-Length must be the byte length of the UTF-8 body, not its char count.
        int bodyLength = message.getBytes(StandardCharsets.UTF_8).length;
        String response = "HTTP/1.1 " + statusCode + " " + statusText + "\r\n" +
                extraHeaders +
                "Content-Type: text/plain;charset=UTF-8\r\n" +
                "Content-Length: " + bodyLength + "\r\n" +
                "\r\n" +
                message;
        writeFully(channel, response);
    }

    /**
     * Writes the UTF-8 encoding of {@code text} to the channel, retrying until every byte is written.
     *
     * <p>NOTE(review): if the channel is in non-blocking mode, {@code write} may return 0 and this
     * loop will spin until the socket send buffer drains. Confirm the channel's mode at the call site.
     */
    private static void writeFully(SocketChannel channel, String text) throws IOException {
        ByteBuffer buffer = ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8));
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }

    /**
     * Looks up a header value by name, ignoring case.
     *
     * @return the value, or {@code null} if the header is absent or {@code headers} is null
     */
    private static String getHeaderIgnoreCase(Map<String, String> headers, String name) {
        if (headers == null) {
            return null;
        }
        // Fast path: exact match (or a map that is already case-insensitive).
        String exact = headers.get(name);
        if (exact != null) {
            return exact;
        }
        for (Map.Entry<String, String> entry : headers.entrySet()) {
            if (name.equalsIgnoreCase(entry.getKey())) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Returns whether a comma-separated header value contains {@code token}, compared
     * case-insensitively after trimming surrounding whitespace from each element.
     */
    private static boolean containsToken(String headerValue, String token) {
        if (headerValue == null) {
            return false;
        }
        for (String part : headerValue.split(",")) {
            if (token.equalsIgnoreCase(part.strip())) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code key} is base64 that decodes to exactly 16 bytes. */
    private static boolean isValidKey(String key) {
        if (key == null || key.isBlank()) {
            return false;
        }
        try {
            return Base64.getDecoder().decode(key.strip()).length == KEY_NONCE_LENGTH;
        } catch (IllegalArgumentException notBase64) {
            return false;
        }
    }
}