MaskingInputStream.java
package sprout.server.websocket;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
public class MaskingInputStream extends FilterInputStream {
private final byte[] maskingKey;
private long bytesRead; // 마스킹 키 인덱스 계산용
public MaskingInputStream(InputStream in, byte[] maskingKey) {
super(in);
if (maskingKey == null || maskingKey.length != 4) {
throw new IllegalArgumentException("Masking key must be 4 bytes long.");
}
this.maskingKey = maskingKey;
this.bytesRead = 0;
}
@Override
public int read() throws IOException {
int r = super.read();
if (r != -1) {
// FIX : maskingKey 가 byte라서 int로 승격될 때 음수로 확장되고 ^ 연산 결과가 0255 범위를 벗어나 음수가 됨
// 255 또는 -1만 반환해야 하니까 결과를 & 0xFF로 정리
int k = maskingKey[(int) (bytesRead & 3)] & 0xFF; // 키도 0~255로
r = (r ^ k) & 0xFF; // 결과도 0~255로
bytesRead++;
}
return r;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
int n = super.read(b, off, len);
if (n != -1) {
for (int i = 0; i < n; i++) {
// FIX: 일관성 위해 키 마스킹 추가
int k = maskingKey[(int) ((bytesRead + i) & 3)] & 0xFF;
b[off + i] = (byte) ((b[off + i] ^ k) & 0xFF);
}
bytesRead += n;
}
return n;
}
}