DefaultWebSocketSession.java

package sprout.server.websocket;

import sprout.mvc.http.HttpRequest;
import sprout.server.WritableHandler;
import sprout.server.argument.WebSocketArgumentResolver;
import sprout.server.websocket.exception.NotEnoughDataException;
import sprout.server.websocket.exception.WebSocketException;
import sprout.server.websocket.endpoint.WebSocketEndpointInfo;
import sprout.server.websocket.framehandler.FrameHandler;
import sprout.server.websocket.framehandler.FrameProcessingContext;
import sprout.server.websocket.message.*;

import java.io.*;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

import static java.nio.channels.SelectionKey.OP_WRITE;

public class DefaultWebSocketSession implements WebSocketSession, WritableHandler {
    private final String id;
    private final SocketChannel channel;
    private final Selector selector;
    private final HttpRequest<?> handshakeRequest;
    private final Map<String, String> pathParameters;
    private final WebSocketEndpointInfo endpointInfo;
    private final WebSocketFrameParser frameParser;
    private final WebSocketFrameEncoder frameEncoder;
    private final List<WebSocketArgumentResolver> argumentResolvers;
    private final List<WebSocketMessageDispatcher> messageDispatchers;
    private final CloseListener closeListener;

    private final FrameProcessingContext processingContext;
    private volatile boolean open = true;
    private volatile boolean isClosePending = false;
    private final Map<String, Object> userProperties = new ConcurrentHashMap<>();
    private final ByteBuffer readBuffer = ByteBuffer.allocate(65536);
    private final Queue<ByteBuffer> pendingWrites = new ConcurrentLinkedQueue<>();

    private final WebSocketFrameDispatcher frameDispatcher;

    public DefaultWebSocketSession(String id, SocketChannel channel, Selector selector, HttpRequest<?> handshakeRequest, WebSocketEndpointInfo endpointInfo, WebSocketFrameParser frameParser, WebSocketFrameEncoder frameEncoder, Map<String, String> pathParameters, List<WebSocketArgumentResolver> webSocketArgumentResolvers, List<WebSocketMessageDispatcher> messageDispatchers, CloseListener closeListener, List<FrameHandler> frameHandlers) throws IOException {
        this.id = id;
        this.channel = channel;
        this.selector = selector;
        this.handshakeRequest = handshakeRequest;
        this.endpointInfo = endpointInfo;
        this.frameParser = frameParser;
        this.frameEncoder = frameEncoder;
        this.pathParameters = pathParameters;
        this.argumentResolvers = webSocketArgumentResolvers;
        this.messageDispatchers = messageDispatchers;
        this.closeListener = closeListener;
        this.frameDispatcher = new WebSocketFrameDispatcher(frameHandlers, messageDispatchers);
        this.processingContext = new FrameProcessingContext();
    }

    @Override
    public void close() throws IOException {
        if (open && !isClosePending) {
            System.out.println("Scheduling close for WebSocket session: " + id);
            isClosePending = true; // 종료 요청 표시
            // 종료 프레임 생성 (opcode 0x8, 정상 종료 코드 1000)
            String closeReason = "Closing WebSocket session: " + id + ", Close code is: " + CloseCodes.NORMAL_CLOSURE.getCode() + ".";
            byte[] closePayload = closeReason.getBytes(StandardCharsets.UTF_8);
            byte[] encoded = frameEncoder.encodeControlFrame(0x8, closePayload);
            scheduleWrite(ByteBuffer.wrap(encoded));
        }
    }

    @Override
    public HttpRequest<?> getHandshakeRequest() {
        return handshakeRequest;
    }

    @Override
    public void sendText(String message) throws IOException {
        scheduleWrite(ByteBuffer.wrap(frameEncoder.encodeText(message)));
    }

    @Override
    public void write(SelectionKey key) throws Exception {
        ByteBuffer buf;
        while ((buf = pendingWrites.peek()) != null) {
            channel.write(buf);
            if (buf.hasRemaining()) return;
            pendingWrites.poll();
        }

        if (pendingWrites.isEmpty()) {
            key.interestOps(key.interestOps() & ~OP_WRITE);
            // 큐가 비었고 종료 요청이 있었다면 채널 닫기
            if (isClosePending && open) {
                System.out.println("All pending writes completed, closing channel for session: " + id);
                open = false;
                channel.close();
                if (closeListener != null) {
                    closeListener.onSessionClosed(this);
                }
            }
        }
    }

    @Override
    public void sendBinary(byte[] data) throws IOException {
        scheduleWrite(ByteBuffer.wrap(frameEncoder.encodeBinary(data)));
    }

    @Override
    public void sendPing(byte[] data) throws IOException {
        scheduleWrite(ByteBuffer.wrap(frameEncoder.encodeControlFrame(0x9, data)));
    }

    @Override
    public void sendPong(byte[] data) throws IOException {
        scheduleWrite(ByteBuffer.wrap(frameEncoder.encodeControlFrame(0xA, data)));
    }

    @Override
    public String getId() {
        return id;
    }

    @Override
    public boolean isOpen() {
        return open && channel.isOpen();
    }

    @Override
    public Map<String, Object> getUserProperties() {
        return userProperties;
    }

    @Override
    public void read(SelectionKey key) throws Exception {
        int bytesRead = channel.read(readBuffer);
        if (bytesRead == -1) {
            callOnCloseMethod(CloseCodes.NO_STATUS_CODE);
            close();
            return;
        }

        readBuffer.flip();

        while (readBuffer.remaining() > 0) {
            // 파싱 전에 현재 위치를 마크 (파싱 실패 시 복구 위함)
            readBuffer.mark();

            // ByteBuffer를 직접 읽는 InputStream 어댑터 사용
            InputStream frameInputStream = new ByteBufferInputStream(readBuffer);

            try {
                WebSocketFrame frame = frameParser.parse(frameInputStream);
                // 성공적으로 파싱되면, 실제 처리 로직 실행
                processFrame(frame);
            } catch (NotEnoughDataException e) {
                // 버퍼에 아직 완전한 프레임이 없음 -> 다음 read 이벤트를 기다림
                readBuffer.reset(); // 마크한 위치로 복구
                break; // while 루프 종료
            }
        }
        readBuffer.compact();
    }

    private void scheduleWrite(ByteBuffer buf) {
        pendingWrites.add(buf);
        SelectionKey key = channel.keyFor(selector);
        if (key != null && key.isValid() && (key.interestOps() & OP_WRITE) == 0) {
            // | 연산자로 OP_WRITE 플래그를 추가
            key.interestOps(key.interestOps() | OP_WRITE);
            selector.wakeup(); // Selector가 select()에서 대기 중일 수 있으므로 깨워주기
        }
    }

    private void processFrame(WebSocketFrame frame) throws Exception {
        if (WebSocketFrameDecoder.isCloseFrame(frame)) {
            callOnCloseMethod(WebSocketFrameDecoder.getCloseCode(frame.getPayloadBytes()));
            return;
        } else if (WebSocketFrameDecoder.isPingFrame(frame)) {
            System.out.println("Received Ping frame from client " + id);
            sendPong(frame.getPayloadBytes());
        } else if (WebSocketFrameDecoder.isPongFrame(frame)) {
            System.out.println("Received Pong frame from client " + id);
        } else if (WebSocketFrameDecoder.isDataFrame(frame)) {
            dispatchMessage(frame);
        } else {
            System.err.println("Unknown or unsupported WebSocket opcode: " + frame.getOpcode());
            callOnErrorMethod(new WebSocketException("Unknown WebSocket opcode: " + frame.getOpcode()));
        }
    }


    public void dispatchMessage(WebSocketFrame frame) throws Exception {
        this.processingContext.setCurrentFrame(frame);
        try {
            frameDispatcher.dispatch(this.processingContext, this, pathParameters);
        } catch (Exception e) {
            System.err.println("Error dispatching frame: " + e.getMessage());
            callOnErrorMethod(e); // 에러 핸들러 호출
            close(); // 치명적 오류 시 연결 종료
        }
    }

    @Override
    public WebSocketEndpointInfo getEndpointInfo() {
        return endpointInfo;
    }

    @Override
    public void callOnOpenMethod() throws Exception{
        Method onOpenMethod = endpointInfo.getOnOpenMethod();
        if (onOpenMethod == null) return;

        // InvocationContext 생성
        InvocationContext context = new DefaultInvocationContext(handshakeRequest, this, pathParameters);

        Object[] args = resolveArgs(onOpenMethod, context);
        onOpenMethod.invoke(endpointInfo.getHandlerBean(), args);
    }

    @Override
    public void callOnErrorMethod(Throwable error) throws Exception {
        Method onErrorMethod = endpointInfo.getOnErrorMethod();
        if (onErrorMethod == null) return;

        // InvocationContext 생성
        InvocationContext context = new DefaultInvocationContext(this, pathParameters, error);

        Object[] args = resolveArgs(onErrorMethod, context);
        onErrorMethod.invoke(endpointInfo.getHandlerBean(), args);
    }

    @Override
    public void callOnCloseMethod(CloseCode closeCode) throws Exception {
        Method onCloseMethod = endpointInfo.getOnCloseMethod();
        if (onCloseMethod == null) return;

        // InvocationContext 생성
        InvocationContext context = new DefaultInvocationContext(this, pathParameters, closeCode);

        Object[] args = resolveArgs(onCloseMethod, context);
        onCloseMethod.invoke(endpointInfo.getHandlerBean(), args);
    }

    private Object[] resolveArgs(Method method, InvocationContext context) throws Exception {
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];

        for (int i = 0; i < parameters.length; i++) {
            boolean resolved = false;
            for (WebSocketArgumentResolver resolver : argumentResolvers) {
                if (resolver.supports(parameters[i], context)) { // <- InvocationContext 전달
                    args[i] = resolver.resolve(parameters[i], context); // <- InvocationContext 전달
                    resolved = true;
                    break;
                }
            }
            if (!resolved) {
                throw new IllegalArgumentException("No WebSocketArgumentResolver found for parameter: " + parameters[i].getName() + " in method " + method.getName() + " for phase " + context.phase());
            }
        }
        return args;
    }

    @Override
    public String getRequestPath() {
        return handshakeRequest.getPath();
    }

    @Override
    public Map<String, List<String>> getRequestParameterMap() {
        return handshakeRequest.getQueryParams().entrySet().stream()
                .collect(java.util.stream.Collectors.toMap(
                        Map.Entry::getKey,
                        e -> Collections.singletonList(e.getValue()) // String을 List<String>으로 변환
                ));
    }

    @Override
    public String getQueryString() {
        return handshakeRequest.getQueryParams().entrySet().stream()
                .map(entry -> entry.getKey() + "=" + entry.getValue())
                .collect(java.util.stream.Collectors.joining("&"));
    }

    @Override
    public Map<String, String> getPathParameters() {
        return this.pathParameters;
    }

    public boolean isClosePending() {
        return isClosePending;
    }
}