fix(dm): WebSocket heartbeat와 token 재연결을 보정한다

This commit is contained in:
2026-06-18 23:31:15 +09:00
parent a6485292e4
commit 8f69c1ab82
2 changed files with 190 additions and 5 deletions

View File

@@ -48,6 +48,8 @@ class DmChatRoomViewModel(
private var currentRealtimeToken: String = "" private var currentRealtimeToken: String = ""
private var currentRealtimeRoomId: Long = 0L private var currentRealtimeRoomId: Long = 0L
private var reconnectDisposable: Disposable? = null private var reconnectDisposable: Disposable? = null
private var heartbeatPingDisposable: Disposable? = null
private var heartbeatTimeoutDisposable: Disposable? = null
private var localMessageSequence: Long = 0L private var localMessageSequence: Long = 0L
private var requestSequence: Long = 0L private var requestSequence: Long = 0L
private val pendingRequestLocalIds = mutableMapOf<String, String>() private val pendingRequestLocalIds = mutableMapOf<String, String>()
@@ -188,7 +190,16 @@ class DmChatRoomViewModel(
private fun connectRealtime(token: String) { private fun connectRealtime(token: String) {
val roomId = currentRoomId val roomId = currentRoomId
if (roomId <= 0L || isRealtimeConnected && currentRealtimeRoomId == roomId) return if (roomId <= 0L) return
if (currentRealtimeToken.isNotEmpty() && currentRealtimeToken != token && currentRealtimeRoomId == roomId) {
stopHeartbeat()
repository.closeSocket()
isRealtimeJoining = false
isRealtimeConnected = false
currentRealtimeToken = ""
currentRealtimeRoomId = 0L
}
if (isRealtimeConnected && currentRealtimeRoomId == roomId) return
if (isRealtimeJoining && currentRealtimeRoomId == roomId) return if (isRealtimeJoining && currentRealtimeRoomId == roomId) return
if (!shouldReconnectRealtime && currentRealtimeToken.isNotEmpty() && currentRealtimeRoomId == roomId) return if (!shouldReconnectRealtime && currentRealtimeToken.isNotEmpty() && currentRealtimeRoomId == roomId) return
@@ -222,8 +233,12 @@ class DmChatRoomViewModel(
fun leaveRealtime() { fun leaveRealtime() {
val roomId = currentRoomId val roomId = currentRoomId
if (roomId <= 0L) return if (roomId <= 0L) return
val hasActiveSocket = currentRealtimeRoomId == roomId &&
(isRealtimeJoining || isRealtimeConnected || currentRealtimeToken.isNotEmpty())
if (!hasActiveSocket) return
shouldReconnectRealtime = false shouldReconnectRealtime = false
stopHeartbeat()
currentRealtimeToken = "" currentRealtimeToken = ""
currentRealtimeRoomId = 0L currentRealtimeRoomId = 0L
isRealtimeJoining = false isRealtimeJoining = false
@@ -243,7 +258,7 @@ class DmChatRoomViewModel(
reconnectDisposable = reconnectScheduler.scheduleDirect( reconnectDisposable = reconnectScheduler.scheduleDirect(
{ {
scheduleRealtimeCallback { scheduleRealtimeCallback {
if (shouldReconnectRealtime) connectRealtime(token = token) if (shouldReconnectRealtime) connectRealtime(token = authToken().ifBlank { token })
} }
}, },
RECONNECT_DELAY_MILLIS, RECONNECT_DELAY_MILLIS,
@@ -261,6 +276,7 @@ class DmChatRoomViewModel(
override fun onCleared() { override fun onCleared() {
mainHandler.removeCallbacksAndMessages(null) mainHandler.removeCallbacksAndMessages(null)
stopHeartbeat()
reconnectDisposable?.dispose() reconnectDisposable?.dispose()
reconnectDisposable = null reconnectDisposable = null
currentRealtimeRoomId = 0L currentRealtimeRoomId = 0L
@@ -362,15 +378,65 @@ class DmChatRoomViewModel(
DmChatSocketEvent.Joined -> { DmChatSocketEvent.Joined -> {
isRealtimeJoining = false isRealtimeJoining = false
isRealtimeConnected = true isRealtimeConnected = true
startHeartbeat()
syncLatestMessagesAfterReconnect(token = token) syncLatestMessagesAfterReconnect(token = token)
} }
is DmChatSocketEvent.Message -> handleRealtimeMessage(event.requestId, event.message) is DmChatSocketEvent.Message -> handleRealtimeMessage(event.requestId, event.message)
is DmChatSocketEvent.SendAck -> handleSendAck(event.requestId, event.message) is DmChatSocketEvent.SendAck -> handleSendAck(event.requestId, event.message)
is DmChatSocketEvent.Error -> event.requestId?.let { markPendingMessageFailed(it) } is DmChatSocketEvent.Error -> event.requestId?.let { markPendingMessageFailed(it) }
DmChatSocketEvent.Pong -> Unit DmChatSocketEvent.Pong -> clearHeartbeatTimeout()
} }
} }
private fun startHeartbeat() {
stopHeartbeat()
heartbeatPingDisposable = reconnectScheduler.schedulePeriodicallyDirect(
{
scheduleRealtimeCallback {
if (!isRealtimeConnected || !shouldReconnectRealtime) return@scheduleRealtimeCallback
val latestToken = authToken()
if (latestToken.isNotBlank() && latestToken != currentRealtimeToken) {
connectRealtime(token = latestToken)
return@scheduleRealtimeCallback
}
if (repository.sendPing()) scheduleHeartbeatTimeout()
}
},
HEARTBEAT_INTERVAL_MILLIS,
HEARTBEAT_INTERVAL_MILLIS,
TimeUnit.MILLISECONDS
).also { compositeDisposable.add(it) }
}
private fun scheduleHeartbeatTimeout() {
heartbeatTimeoutDisposable?.dispose()
heartbeatTimeoutDisposable = reconnectScheduler.scheduleDirect(
{
scheduleRealtimeCallback {
if (!isRealtimeConnected || !shouldReconnectRealtime) return@scheduleRealtimeCallback
isRealtimeJoining = false
isRealtimeConnected = false
stopHeartbeat()
repository.closeSocket()
scheduleRealtimeReconnect()
}
},
HEARTBEAT_TIMEOUT_MILLIS,
TimeUnit.MILLISECONDS
).also { compositeDisposable.add(it) }
}
private fun stopHeartbeat() {
heartbeatPingDisposable?.dispose()
heartbeatPingDisposable = null
clearHeartbeatTimeout()
}
private fun clearHeartbeatTimeout() {
heartbeatTimeoutDisposable?.dispose()
heartbeatTimeoutDisposable = null
}
private fun handleSendAck(requestId: String, message: DmChatMessageResponse) { private fun handleSendAck(requestId: String, message: DmChatMessageResponse) {
val localId = pendingRequestLocalIds.remove(requestId) val localId = pendingRequestLocalIds.remove(requestId)
?: recentFailedRequestLocalIds.remove(requestId) ?: recentFailedRequestLocalIds.remove(requestId)
@@ -475,6 +541,8 @@ class DmChatRoomViewModel(
private companion object { private companion object {
const val RECONNECT_DELAY_MILLIS = 3_000L const val RECONNECT_DELAY_MILLIS = 3_000L
const val SEND_ACK_TIMEOUT_MILLIS = 10_000L const val SEND_ACK_TIMEOUT_MILLIS = 10_000L
const val HEARTBEAT_INTERVAL_MILLIS = 30_000L
const val HEARTBEAT_TIMEOUT_MILLIS = 10_000L
} }
} }

View File

@@ -55,6 +55,7 @@ class DmChatRoomViewModelTest {
private lateinit var socketClient: DmChatSocketClient private lateinit var socketClient: DmChatSocketClient
private lateinit var reconnectScheduler: TestScheduler private lateinit var reconnectScheduler: TestScheduler
private lateinit var viewModel: DmChatRoomViewModel private lateinit var viewModel: DmChatRoomViewModel
private var token: String = "test-token"
@Before @Before
fun setUp() { fun setUp() {
@@ -62,6 +63,7 @@ class DmChatRoomViewModelTest {
SharedPreferenceManager.resetForTest() SharedPreferenceManager.resetForTest()
SharedPreferenceManager.init(context) SharedPreferenceManager.init(context)
SharedPreferenceManager.token = "test-token" SharedPreferenceManager.token = "test-token"
token = "test-token"
api = FakeDmChatApi() api = FakeDmChatApi()
socketFactory = FakeWebSocketFactory() socketFactory = FakeWebSocketFactory()
socketClient = DmChatSocketClient( socketClient = DmChatSocketClient(
@@ -74,7 +76,7 @@ class DmChatRoomViewModelTest {
viewModel = DmChatRoomViewModel( viewModel = DmChatRoomViewModel(
repository = DmChatRepository(api, socketClient), repository = DmChatRepository(api, socketClient),
reconnectScheduler = reconnectScheduler, reconnectScheduler = reconnectScheduler,
tokenProvider = { "test-token" } tokenProvider = { token }
) )
} }
@@ -651,7 +653,11 @@ class DmChatRoomViewModelTest {
).readText() ).readText()
val compactSource = source.filterNot { it.isWhitespace() } val compactSource = source.filterNot { it.isWhitespace() }
assertTrue(compactSource.contains("scheduleRealtimeCallback{if(shouldReconnectRealtime)connectRealtime(token=token)}")) assertTrue(
compactSource.contains(
"scheduleRealtimeCallback{if(shouldReconnectRealtime)connectRealtime(token=authToken().ifBlank{token})}"
)
)
assertTrue(!compactSource.contains("scheduleDirect({connectRealtime(token=token)}")) assertTrue(!compactSource.contains("scheduleDirect({connectRealtime(token=token)}"))
} }
@@ -714,6 +720,109 @@ class DmChatRoomViewModelTest {
assertEquals(2, socketFactory.connectCalls.size) assertEquals(2, socketFactory.connectCalls.size)
} }
@Test
fun `leave는 LEAVE_ROOM 전송 후 socket을 close하고 중복 호출은 무시한다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
viewModel.leaveRealtime()
viewModel.leaveRealtime()
assertEquals(listOf("JOIN_ROOM", "LEAVE_ROOM"), socketFactory.webSocket.sentTexts.map { it.type() })
assertEquals(1, socketFactory.closeCount)
}
@Test
fun `JOINED 이후 heartbeat는 PING을 보내고 PONG 수신 시 연결을 유지한다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
api.enqueueMessagesSuccess(messagesPage(messages = emptyList()))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
socketFactory.emitJoined()
reconnectScheduler.advanceTimeBy(30L, TimeUnit.SECONDS)
reconnectScheduler.advanceTimeBy(5L, TimeUnit.SECONDS)
socketFactory.emitPong()
reconnectScheduler.advanceTimeBy(24L, TimeUnit.SECONDS)
assertEquals(listOf("JOIN_ROOM", "PING"), socketFactory.webSocket.sentTexts.map { it.type() })
assertEquals(true, viewModel.isRealtimeConnectedForTest())
assertEquals(0, socketFactory.closeCount)
}
@Test
fun `heartbeat PONG timeout은 socket close 후 foreground 조건에서 reconnect를 예약한다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
api.enqueueMessagesSuccess(messagesPage(messages = emptyList()))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
socketFactory.emitJoined()
reconnectScheduler.advanceTimeBy(30L, TimeUnit.SECONDS)
reconnectScheduler.advanceTimeBy(10L, TimeUnit.SECONDS)
reconnectScheduler.advanceTimeBy(2999L, TimeUnit.MILLISECONDS)
assertEquals(false, viewModel.isRealtimeConnectedForTest())
assertEquals(1, socketFactory.closeCount)
assertEquals(1, socketFactory.connectCalls.size)
reconnectScheduler.advanceTimeBy(1L, TimeUnit.MILLISECONDS)
assertEquals(2, socketFactory.connectCalls.size)
assertEquals("JOIN_ROOM", socketFactory.webSocket.sentTexts.lastJson().get("type").asString)
}
@Test
fun `leave는 heartbeat timeout과 reconnect 예약을 취소한다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
api.enqueueMessagesSuccess(messagesPage(messages = emptyList()))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
socketFactory.emitJoined()
reconnectScheduler.advanceTimeBy(30L, TimeUnit.SECONDS)
viewModel.leaveRealtime()
reconnectScheduler.advanceTimeBy(13L, TimeUnit.SECONDS)
assertEquals(listOf(RealtimeConnectCall("test-token", 10L)), socketFactory.connectCalls)
assertEquals(listOf("JOIN_ROOM", "PING", "LEAVE_ROOM"), socketFactory.webSocket.sentTexts.map { it.type() })
assertEquals(1, socketFactory.closeCount)
}
@Test
fun `token이 변경되면 기존 socket을 close하고 새 token으로 다시 JOIN_ROOM을 보낸다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
api.enqueueMessagesSuccess(messagesPage(messages = emptyList()))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
socketFactory.emitJoined()
token = "new-token"
viewModel.connectRealtime()
assertEquals(
listOf(RealtimeConnectCall("test-token", 10L), RealtimeConnectCall("new-token", 10L)),
socketFactory.connectCalls
)
assertEquals(1, socketFactory.closeCount)
assertEquals("JOIN_ROOM", socketFactory.webSocket.sentTexts.lastJson().get("type").asString)
}
@Test
fun `leave 이후 token이 변경되어도 socket reconnect를 진행하지 않는다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L))
viewModel.enter(roomId = 10L, creatorId = 0L)
viewModel.connectRealtime()
viewModel.leaveRealtime()
token = "new-token"
reconnectScheduler.advanceTimeBy(30L, TimeUnit.SECONDS)
assertEquals(listOf(RealtimeConnectCall("test-token", 10L)), socketFactory.connectCalls)
assertEquals(1, socketFactory.closeCount)
}
@Test @Test
fun `realtime leave는 채팅 상태를 Error로 바꾸지 않는다`() { fun `realtime leave는 채팅 상태를 Error로 바꾸지 않는다`() {
api.enqueueOpenSuccess(openResponse(roomId = 10L, messages = listOf(message(messageId = 1L, textMessage = "기존")))) api.enqueueOpenSuccess(openResponse(roomId = 10L, messages = listOf(message(messageId = 1L, textMessage = "기존"))))
@@ -972,6 +1081,10 @@ class FakeWebSocketFactory {
) )
} }
fun emitPong() {
webSocketListener?.onMessage(webSocket, "{\"type\":\"PONG\",\"payload\":{}}")
}
fun emitFailure(throwable: Throwable) { fun emitFailure(throwable: Throwable) {
webSocketListener?.onFailure(webSocket, throwable, null) webSocketListener?.onFailure(webSocket, throwable, null)
} }
@@ -997,3 +1110,7 @@ class FakeWebSocket : WebSocket {
fun sentJsonAt(index: Int) = JsonParser.parseString(sentTexts[index]).asJsonObject fun sentJsonAt(index: Int) = JsonParser.parseString(sentTexts[index]).asJsonObject
} }
private fun String.type(): String = JsonParser.parseString(this).asJsonObject.get("type").asString
private fun List<String>.lastJson() = JsonParser.parseString(last()).asJsonObject