From b6d6986b62a1fd124dce8f29f4f1ba12b1ccb1f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Calvi=C3=B1o=20S=C3=A1nchez?= Date: Sun, 8 Dec 2024 05:23:11 +0100 Subject: [PATCH] Fix "removePeerConnection" not being thread-safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adding and disposing remote data channels is done from different threads; they are added from the WebRTC signaling thread when "onDataChannel" is called, while they can be disposed potentially from any thread when "removePeerConnection" is called. To prevent race conditions between them now both operations are synchronized. However, as "onDataChannel" belongs to an inner class it needs to use a synchronized statement with the outer class lock. This could still cause a race condition if the same data channel was added again; this should not happen, but it is handled just in case. Moreover, once a data channel is disposed it can be no longer used, and trying to call any of its methods throws an "IllegalStateException". Due to this, as sending can be also done potentially from any thread, it needs to be synchronized too with removing the peer connection. State changes on data channels as well as receiving messages are also done in the WebRTC signaling thread. State changes needs synchronization as well, although receiving messages should not, as it does not directly use the data channel (and it is assumed that using the buffers of a disposed data channel is safe). Nevertheless a little check (which in this case requires synchronization) was added to ignore the received messages if the peer connection was removed already. Finally, the synchronization added to "send" and "onStateChange" had the nice side effect of making the pending data channel messages thread-safe too, as before it could happen that a message was enqueued when the pending messages were being sent, which caused a "ConcurrentModificationException". Signed-off-by: Daniel Calviño Sánchez --- .../talk/webrtc/PeerConnectionWrapper.java | 88 ++-- .../talk/webrtc/PeerConnectionWrapperTest.kt | 376 ++++++++++++++++++ 2 files changed, 440 insertions(+), 24 deletions(-) diff --git a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java index 08da1f726..13bb5f2f7 100644 --- a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java +++ b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java @@ -235,7 +235,7 @@ public class PeerConnectionWrapper { return stream; } - public void removePeerConnection() { + public synchronized void removePeerConnection() { signalingMessageReceiver.removeListener(webRtcMessageListener); for (DataChannel dataChannel: dataChannels.values()) { @@ -292,7 +292,7 @@ public class PeerConnectionWrapper { * * @param dataChannelMessage the message to send */ - public void send(DataChannelMessage dataChannelMessage) { + public synchronized void send(DataChannelMessage dataChannelMessage) { if (dataChannelMessage == null) { return; } @@ -414,20 +414,38 @@ public class PeerConnectionWrapper { @Override public void onStateChange() { - if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) { - for (DataChannelMessage dataChannelMessage: pendingDataChannelMessages) { - send(dataChannelMessage); + synchronized (PeerConnectionWrapper.this) { + // The PeerConnection could have been removed in parallel even with the synchronization (as just after + // "onStateChange" was called "removePeerConnection" could have acquired the lock). + if (peerConnection == null) { + return; } - pendingDataChannelMessages.clear(); - } - if (dataChannel.state() == DataChannel.State.OPEN) { - sendInitialMediaStatus(); + if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) { + for (DataChannelMessage dataChannelMessage : pendingDataChannelMessages) { + send(dataChannelMessage); + } + pendingDataChannelMessages.clear(); + } + + if (dataChannel.state() == DataChannel.State.OPEN) { + sendInitialMediaStatus(); + } } } @Override public void onMessage(DataChannel.Buffer buffer) { + synchronized (PeerConnectionWrapper.this) { + // It is assumed that, even if its data channel was disposed, its buffers can be used while there is + // a reference to them, so it would not be necessary to check this from a thread-safety point of view. + // Nevertheless, if the remote peer connection was removed it would not make sense to notify the + // listeners anyway. + if (peerConnection == null) { + return; + } + } + if (buffer.binary) { Log.d(TAG, "Received binary data channel message over " + dataChannelLabel + " " + sessionId); return; @@ -557,23 +575,45 @@ public class PeerConnectionWrapper { @Override public void onDataChannel(DataChannel dataChannel) { - // Another data channel with the same label, no matter if the same instance or a different one, should not - // be added, but just in case. - DataChannel oldDataChannel = dataChannels.get(dataChannel.label()); - if (oldDataChannel == dataChannel) { - Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again"); + synchronized (PeerConnectionWrapper.this) { + // Another data channel with the same label, no matter if the same instance or a different one, should + // not be added, but this is handled just in case. + // Moreover, if it were possible that an already added data channel was added again there would be a + // potential race condition with "removePeerConnection", even with the synchronization, as it would + // be possible that "onDataChannel" was called, then "removePeerConnection" disposed the data + // channel, and then "onDataChannel" continued in the synchronized statements and tried to get the + // label, which would throw an exception due to the data channel having been disposed already. + String dataChannelLabel; + try { + dataChannelLabel = dataChannel.label(); + } catch (IllegalStateException e) { + // The data channel was disposed already, nothing to do. + return; + } - return; + DataChannel oldDataChannel = dataChannels.get(dataChannelLabel); + if (oldDataChannel == dataChannel) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again"); + + return; + } + + if (oldDataChannel != null) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists"); + + oldDataChannel.dispose(); + } + + // If the peer connection was removed in parallel dispose the data channel instead of adding it. + if (peerConnection == null) { + dataChannel.dispose(); + + return; + } + + dataChannel.registerObserver(new DataChannelObserver(dataChannel)); + dataChannels.put(dataChannel.label(), dataChannel); } - - if (oldDataChannel != null) { - Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists"); - - oldDataChannel.dispose(); - } - - dataChannel.registerObserver(new DataChannelObserver(dataChannel)); - dataChannels.put(dataChannel.label(), dataChannel); } @Override diff --git a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt index 61c5f5daa..f4f7e6443 100644 --- a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt +++ b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt @@ -19,15 +19,22 @@ import org.mockito.ArgumentMatchers.any import org.mockito.ArgumentMatchers.argThat import org.mockito.ArgumentMatchers.eq import org.mockito.Mockito +import org.mockito.Mockito.atLeast +import org.mockito.Mockito.atMostOnce +import org.mockito.Mockito.doAnswer import org.mockito.Mockito.doNothing import org.mockito.Mockito.never +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.webrtc.DataChannel import org.webrtc.MediaConstraints import org.webrtc.PeerConnection import org.webrtc.PeerConnectionFactory import java.nio.ByteBuffer import java.util.HashMap +import kotlin.concurrent.thread +@Suppress("LongMethod", "TooGenericExceptionCaught") class PeerConnectionWrapperTest { private var peerConnectionWrapper: PeerConnectionWrapper? = null @@ -36,6 +43,23 @@ class PeerConnectionWrapperTest { private var mockedSignalingMessageReceiver: SignalingMessageReceiver? = null private var mockedSignalingMessageSender: SignalingMessageSender? = null + /** + * Helper answer for DataChannel methods. + */ + private class ReturnValueOrThrowIfDisposed(val value: T) : + Answer { + override fun answer(currentInvocation: InvocationOnMock): T { + if (Mockito.mockingDetails(currentInvocation.mock).invocations.find { + it!!.method.name === "dispose" + } !== null + ) { + throw IllegalStateException("DataChannel has been disposed") + } + + return value + } + } + /** * Helper matcher for DataChannelMessages. */ @@ -195,6 +219,83 @@ class PeerConnectionWrapperTest { ) } + @Test + fun testSendDataChannelMessageBeforeOpeningDataChannelWithDifferentThreads() { + // A brute force approach is used to test race conditions between different threads just repeating the test + // several times. Due to this the test passing could be a false positive, as it could have been a matter of + // luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with + // that number of reruns, it fails when it should). + for (i in 1..1000) { + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + any(PeerConnection.Observer::class.java) + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenReturn("status") + Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.CONNECTING) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + val statusDataChannelObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(DataChannel.Observer::class.java) + + doNothing().`when`(mockedStatusDataChannel) + .registerObserver(statusDataChannelObserverArgumentCaptor.capture()) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelMessageCount = 5 + + val sendThread = thread { + for (j in 1..dataChannelMessageCount) { + peerConnectionWrapper!!.send(DataChannelMessage("the-message-type-$j")) + } + } + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for ConcurrentModificationExceptions when iterating over the data channel messages). + var exceptionOnStateChange: Exception? = null + + val openDataChannelThread = thread { + Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.OPEN) + + try { + statusDataChannelObserverArgumentCaptor.value.onStateChange() + } catch (e: Exception) { + exceptionOnStateChange = e + } + } + + sendThread.join() + openDataChannelThread.join() + + if (exceptionOnStateChange !== null) { + throw exceptionOnStateChange!! + } + + for (j in 1..dataChannelMessageCount) { + Mockito.verify(mockedStatusDataChannel).send( + argThat(MatchesDataChannelMessage(DataChannelMessage("the-message-type-$j"))) + ) + } + } + } + @Test fun testReceiveDataChannelMessage() { Mockito.`when`( @@ -381,4 +482,279 @@ class PeerConnectionWrapperTest { Mockito.verify(mockedStatusDataChannel).dispose() Mockito.verify(mockedRandomIdDataChannel).dispose() } + + @Test + fun testRemovePeerConnectionWhileAddingRemoteDataChannelsWithDifferentThreads() { + // A brute force approach is used to test race conditions between different threads just repeating the test + // several times. Due to this the test passing could be a false positive, as it could have been a matter of + // luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with + // that number of reruns, it fails when it should). + for (i in 1..1000) { + val peerConnectionObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer( + ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN) + ) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelCount = 5 + + val mockedRandomIdDataChannels: MutableList = ArrayList() + val dataChannelObservers: MutableList = ArrayList() + for (j in 0.. + if (Mockito.mockingDetails(invocation.mock).invocations.find { + it!!.method.name === "dispose" + } !== null + ) { + throw IllegalStateException("DataChannel has been disposed") + } + + dataChannelObservers[j] = invocation.getArgument(0, DataChannel.Observer::class.java) + + null + }.`when`(mockedRandomIdDataChannels[j]).registerObserver(any()) + } + + val onDataChannelThread = thread { + // Add again "status" data channel to test that it is correctly disposed also in that case (which + // should not happen anyway even if it was added by the remote peer, but just in case) + peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedStatusDataChannel) + + for (j in 0.. = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()) + .thenAnswer(ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN)) + Mockito.`when`(mockedStatusDataChannel.send(any())).thenAnswer(ReturnValueOrThrowIfDisposed(true)) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelMessageCount = 5 + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for IllegalStateExceptions when using a disposed data channel). + var exceptionSend: Exception? = null + + val sendThread = thread { + try { + for (j in 0.. = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer( + ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN) + ) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + val statusDataChannelObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(DataChannel.Observer::class.java) + + doNothing().`when`(mockedStatusDataChannel) + .registerObserver(statusDataChannelObserverArgumentCaptor.capture()) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val mockedDataChannelMessageListener = Mockito.mock(DataChannelMessageListener::class.java) + peerConnectionWrapper!!.addListener(mockedDataChannelMessageListener) + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for IllegalStateExceptions when using a disposed data channel). + var exceptionOnMessage: Exception? = null + + val onMessageThread = thread { + try { + // It is assumed that, even if its data channel was disposed, its buffers can be used while there + // is a reference to them, so no special mock behaviour is added to throw an exception in that case. + statusDataChannelObserverArgumentCaptor.value.onMessage( + dataChannelMessageToBuffer(DataChannelMessage("audioOn")) + ) + + statusDataChannelObserverArgumentCaptor.value.onMessage( + dataChannelMessageToBuffer(DataChannelMessage("audioOff")) + ) + } catch (e: Exception) { + exceptionOnMessage = e + } + } + + val removePeerConnectionThread = thread { + peerConnectionWrapper!!.removePeerConnection() + } + + onMessageThread.join() + removePeerConnectionThread.join() + + if (exceptionOnMessage !== null) { + throw exceptionOnMessage!! + } + + Mockito.verify(mockedStatusDataChannel).registerObserver(any()) + Mockito.verify(mockedStatusDataChannel).dispose() + Mockito.verify(mockedStatusDataChannel, atLeast(0)).label() + Mockito.verify(mockedStatusDataChannel, atLeast(0)).state() + Mockito.verifyNoMoreInteractions(mockedStatusDataChannel) + Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOn() + Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOff() + Mockito.verifyNoMoreInteractions(mockedDataChannelMessageListener) + } + } }