feat(dm): WebSocket 클라이언트를 추가한다

This commit is contained in:
2026-06-18 17:41:41 +09:00
parent e76562067f
commit c5bcaf7329
2 changed files with 398 additions and 0 deletions

View File

@@ -0,0 +1,122 @@
package kr.co.vividnext.sodalive.v2.main.chat.dm.data
import com.google.gson.Gson
import com.google.gson.JsonObject
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
class DmChatSocketClient(
private val okHttpClient: OkHttpClient,
private val gson: Gson,
private val baseUrl: String,
private val webSocketFactory: (Request, WebSocketListener) -> WebSocket = okHttpClient::newWebSocket
) {
interface Listener {
fun onEvent(event: DmChatSocketEvent)
fun onFailure(throwable: Throwable)
}
private val parser = DmChatSocketParser(gson)
private var webSocket: WebSocket? = null
@Volatile
private var listener: Listener? = null
@Volatile
private var activeSocket: WebSocket? = null
@Synchronized
fun connect(token: String, listener: Listener) {
close()
this.listener = listener
val socketUrl = socketUrl()
val request = Request.Builder()
.url(socketUrl)
.tag(String::class.java, socketUrl)
.header(HEADER_AUTHORIZATION, bearer(token))
.build()
val socketListener = object : WebSocketListener() {
override fun onMessage(webSocket: WebSocket, text: String) {
if (webSocket != activeSocket) return
parser.parse(text)?.let { event -> this@DmChatSocketClient.listener?.onEvent(event) }
}
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
if (webSocket != activeSocket) return
this@DmChatSocketClient.listener?.onFailure(t)
}
}
webSocket = webSocketFactory(request, socketListener).also { activeSocket = it }
}
fun sendJoinRoom(roomId: Long): Boolean = send(
type = DmChatSocketClientType.JOIN_ROOM,
payload = DmChatSocketRoomPayload(roomId = roomId)
)
fun sendLeaveRoom(roomId: Long): Boolean = send(
type = DmChatSocketClientType.LEAVE_ROOM,
payload = DmChatSocketRoomPayload(roomId = roomId)
)
fun sendText(
roomId: Long,
requestId: String,
textMessage: String
): Boolean = send(
type = DmChatSocketClientType.SEND_TEXT,
payload = DmChatSocketSendTextPayload(
roomId = roomId,
requestId = requestId,
textMessage = textMessage
)
)
fun sendPing(): Boolean = send(
type = DmChatSocketClientType.PING,
payload = JsonObject()
)
@Synchronized
fun close() {
val socket = webSocket ?: return
webSocket = null
activeSocket = null
listener = null
socket.close(NORMAL_CLOSE_CODE, null)
}
private fun send(
type: DmChatSocketClientType,
payload: Any
): Boolean {
val socket = webSocket ?: return false
return socket.send(gson.toJson(DmChatSocketOutboundEnvelope(type = type.value, payload = payload)))
}
private fun socketUrl(): String = baseUrl
.trimEnd('/')
.replacePrefix(oldValue = "https://", newValue = "wss://")
.replacePrefix(oldValue = "http://", newValue = "ws://") + SOCKET_PATH
private fun bearer(token: String) = "Bearer $token"
private data class DmChatSocketOutboundEnvelope(
val type: String,
val payload: Any
)
private companion object {
const val HEADER_AUTHORIZATION = "Authorization"
const val NORMAL_CLOSE_CODE = 1000
const val SOCKET_PATH = "/ws/v2/user-creator-chat"
}
}
private fun String.replacePrefix(oldValue: String, newValue: String): String =
if (startsWith(oldValue)) newValue + removePrefix(oldValue) else this

View File

@@ -0,0 +1,276 @@
package kr.co.vividnext.sodalive.v2.main.chat.dm
import com.google.gson.Gson
import com.google.gson.JsonObject
import com.google.gson.JsonParser
import kr.co.vividnext.sodalive.v2.main.chat.dm.data.DmChatSocketClient
import kr.co.vividnext.sodalive.v2.main.chat.dm.data.DmChatSocketEvent
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Test
import java.io.IOException
class DmChatSocketClientTest {
private val gson = Gson()
@Test
fun `https baseUrl은 wss endpoint로 변환하고 Authorization header를 추가한다`() {
val factory = FakeWebSocketFactory()
val client = client(factory = factory, baseUrl = "https://api.example.com")
client.connect(token = "test-token", listener = TestListener())
assertEquals("wss://api.example.com/ws/v2/user-creator-chat", factory.request?.tag(String::class.java))
assertEquals("Bearer test-token", factory.request?.header("Authorization"))
assertEquals(emptyList<String>(), factory.webSocket.sentTexts)
}
@Test
fun `http baseUrl은 ws endpoint로 변환하고 trailing slash를 제거한다`() {
val factory = FakeWebSocketFactory()
val client = client(factory = factory, baseUrl = "http://10.0.2.2:8080/")
client.connect(token = "test-token", listener = TestListener())
assertEquals("ws://10.0.2.2:8080/ws/v2/user-creator-chat", factory.request?.tag(String::class.java))
}
@Test
fun `sendJoinRoom은 JOIN_ROOM envelope를 전송한다`() {
val factory = FakeWebSocketFactory()
val client = connectedClient(factory)
assertTrue(client.sendJoinRoom(roomId = 10L))
val json = factory.webSocket.singleSentJson()
assertEquals("JOIN_ROOM", json.type())
assertEquals(10L, json.payload().get("roomId").asLong)
}
@Test
fun `sendLeaveRoom은 LEAVE_ROOM envelope를 전송한다`() {
val factory = FakeWebSocketFactory()
val client = connectedClient(factory)
assertTrue(client.sendLeaveRoom(roomId = 10L))
val json = factory.webSocket.singleSentJson()
assertEquals("LEAVE_ROOM", json.type())
assertEquals(10L, json.payload().get("roomId").asLong)
}
@Test
fun `sendText는 SEND_TEXT envelope를 전송한다`() {
val factory = FakeWebSocketFactory()
val client = connectedClient(factory)
assertTrue(client.sendText(roomId = 10L, requestId = "request-1", textMessage = "안녕하세요"))
val json = factory.webSocket.singleSentJson()
assertEquals("SEND_TEXT", json.type())
assertEquals(10L, json.payload().get("roomId").asLong)
assertEquals("request-1", json.payload().get("requestId").asString)
assertEquals("안녕하세요", json.payload().get("textMessage").asString)
}
@Test
fun `sendPing은 PING envelope를 전송한다`() {
val factory = FakeWebSocketFactory()
val client = connectedClient(factory)
assertTrue(client.sendPing())
val json = factory.webSocket.singleSentJson()
assertEquals("PING", json.type())
assertEquals(0, json.payload().size())
}
@Test
fun `onMessage는 parser event를 listener로 전달한다`() {
val factory = FakeWebSocketFactory()
var receivedEvent: DmChatSocketEvent? = null
val client = client(factory = factory)
client.connect(
token = "test-token",
listener = object : TestListener() {
override fun onEvent(event: DmChatSocketEvent) {
receivedEvent = event
}
}
)
factory.listener?.onMessage(factory.webSocket, messageEnvelope())
val messageEvent = receivedEvent as? DmChatSocketEvent.Message
requireNotNull(messageEvent)
assertEquals(10L, messageEvent.message.messageId)
assertEquals("안녕하세요", messageEvent.message.textMessage)
}
@Test
fun `알 수 없는 type과 잘못된 JSON은 listener event로 전달하지 않는다`() {
val factory = FakeWebSocketFactory()
var eventCount = 0
val client = client(factory = factory)
client.connect(
token = "test-token",
listener = object : TestListener() {
override fun onEvent(event: DmChatSocketEvent) {
eventCount += 1
}
}
)
factory.listener?.onMessage(
factory.webSocket,
"""
{
"type": "UNKNOWN",
"payload": {}
}
""".trimIndent()
)
factory.listener?.onMessage(factory.webSocket, "{not-json}")
assertEquals(0, eventCount)
}
@Test
fun `onFailure는 현재 listener로 전달된다`() {
val factory = FakeWebSocketFactory()
var failure: Throwable? = null
val client = client(factory = factory)
client.connect(
token = "test-token",
listener = object : TestListener() {
override fun onFailure(throwable: Throwable) {
failure = throwable
}
}
)
factory.listener?.onFailure(factory.webSocket, IOException("socket failed"), null)
assertEquals("socket failed", failure?.message)
}
@Test
fun `close는 socket을 정상 종료하고 listener를 해제한다`() {
val factory = FakeWebSocketFactory()
var eventCount = 0
val client = client(factory = factory)
client.connect(
token = "test-token",
listener = object : TestListener() {
override fun onEvent(event: DmChatSocketEvent) {
eventCount += 1
}
}
)
val oldListener = factory.listener
client.close()
client.close()
oldListener?.onMessage(factory.webSocket, messageEnvelope())
assertEquals(1, factory.webSocket.closeCount)
assertEquals(1000, factory.webSocket.closeCode)
assertNull(factory.webSocket.closeReason)
assertEquals(0, eventCount)
assertFalse(client.sendPing())
}
private fun connectedClient(factory: FakeWebSocketFactory): DmChatSocketClient =
client(factory = factory).also { it.connect(token = "test-token", listener = TestListener()) }
private fun client(
factory: FakeWebSocketFactory,
baseUrl: String = "https://api.example.com"
): DmChatSocketClient = DmChatSocketClient(
okHttpClient = OkHttpClient(),
gson = gson,
baseUrl = baseUrl,
webSocketFactory = factory::newWebSocket
)
private open class TestListener : DmChatSocketClient.Listener {
override fun onEvent(event: DmChatSocketEvent) = Unit
override fun onFailure(throwable: Throwable) = Unit
}
private class FakeWebSocketFactory {
val webSocket = FakeWebSocket()
var request: Request? = null
var listener: WebSocketListener? = null
fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
this.request = request
this.listener = listener
return webSocket
}
}
private class FakeWebSocket : WebSocket {
val sentTexts = mutableListOf<String>()
var closeCount = 0
var closeCode: Int? = null
var closeReason: String? = null
override fun request(): Request = Request.Builder().url("wss://example.com").build()
override fun queueSize(): Long = 0L
override fun send(text: String): Boolean {
sentTexts += text
return true
}
override fun send(bytes: ByteString): Boolean = true
override fun close(code: Int, reason: String?): Boolean {
closeCount += 1
closeCode = code
closeReason = reason
return true
}
override fun cancel() = Unit
fun singleSentJson(): JsonObject = JsonParser.parseString(sentTexts.single()).asJsonObject
}
private fun JsonObject.type(): String = get("type").asString
private fun JsonObject.payload(): JsonObject = getAsJsonObject("payload")
private fun messageEnvelope(): String =
"""
{
"type": "MESSAGE",
"payload": { "message": ${messageJson()} }
}
""".trimIndent()
private fun messageJson(): String =
"""
{
"messageId": 10,
"messageType": "TEXT",
"mine": false,
"createdAt": 1000,
"textMessage": "안녕하세요",
"voiceMessageUrl": null,
"senderId": 20,
"senderNickname": "크리에이터",
"senderProfileImageUrl": "https://example.com/profile.png"
}
""".trimIndent().replace("\n", "")
}