diff --git a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java index 08da1f726..13bb5f2f7 100644 --- a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java +++ b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java @@ -235,7 +235,7 @@ public class PeerConnectionWrapper { return stream; } - public void removePeerConnection() { + public synchronized void removePeerConnection() { signalingMessageReceiver.removeListener(webRtcMessageListener); for (DataChannel dataChannel: dataChannels.values()) { @@ -292,7 +292,7 @@ public class PeerConnectionWrapper { * * @param dataChannelMessage the message to send */ - public void send(DataChannelMessage dataChannelMessage) { + public synchronized void send(DataChannelMessage dataChannelMessage) { if (dataChannelMessage == null) { return; } @@ -414,20 +414,38 @@ public class PeerConnectionWrapper { @Override public void onStateChange() { - if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) { - for (DataChannelMessage dataChannelMessage: pendingDataChannelMessages) { - send(dataChannelMessage); + synchronized (PeerConnectionWrapper.this) { + // The PeerConnection could have been removed in parallel even with the synchronization (as just after + // "onStateChange" was called "removePeerConnection" could have acquired the lock). + if (peerConnection == null) { + return; } - pendingDataChannelMessages.clear(); - } - if (dataChannel.state() == DataChannel.State.OPEN) { - sendInitialMediaStatus(); + if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) { + for (DataChannelMessage dataChannelMessage : pendingDataChannelMessages) { + send(dataChannelMessage); + } + pendingDataChannelMessages.clear(); + } + + if (dataChannel.state() == DataChannel.State.OPEN) { + sendInitialMediaStatus(); + } } } @Override public void onMessage(DataChannel.Buffer buffer) { + synchronized (PeerConnectionWrapper.this) { + // It is assumed that, even if its data channel was disposed, its buffers can be used while there is + // a reference to them, so it would not be necessary to check this from a thread-safety point of view. + // Nevertheless, if the remote peer connection was removed it would not make sense to notify the + // listeners anyway. + if (peerConnection == null) { + return; + } + } + if (buffer.binary) { Log.d(TAG, "Received binary data channel message over " + dataChannelLabel + " " + sessionId); return; @@ -557,23 +575,45 @@ public class PeerConnectionWrapper { @Override public void onDataChannel(DataChannel dataChannel) { - // Another data channel with the same label, no matter if the same instance or a different one, should not - // be added, but just in case. - DataChannel oldDataChannel = dataChannels.get(dataChannel.label()); - if (oldDataChannel == dataChannel) { - Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again"); + synchronized (PeerConnectionWrapper.this) { + // Another data channel with the same label, no matter if the same instance or a different one, should + // not be added, but this is handled just in case. + // Moreover, if it were possible that an already added data channel was added again there would be a + // potential race condition with "removePeerConnection", even with the synchronization, as it would + // be possible that "onDataChannel" was called, then "removePeerConnection" disposed the data + // channel, and then "onDataChannel" continued in the synchronized statements and tried to get the + // label, which would throw an exception due to the data channel having been disposed already. + String dataChannelLabel; + try { + dataChannelLabel = dataChannel.label(); + } catch (IllegalStateException e) { + // The data channel was disposed already, nothing to do. + return; + } - return; + DataChannel oldDataChannel = dataChannels.get(dataChannelLabel); + if (oldDataChannel == dataChannel) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again"); + + return; + } + + if (oldDataChannel != null) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists"); + + oldDataChannel.dispose(); + } + + // If the peer connection was removed in parallel dispose the data channel instead of adding it. + if (peerConnection == null) { + dataChannel.dispose(); + + return; + } + + dataChannel.registerObserver(new DataChannelObserver(dataChannel)); + dataChannels.put(dataChannel.label(), dataChannel); } - - if (oldDataChannel != null) { - Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists"); - - oldDataChannel.dispose(); - } - - dataChannel.registerObserver(new DataChannelObserver(dataChannel)); - dataChannels.put(dataChannel.label(), dataChannel); } @Override diff --git a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt index 61c5f5daa..f4f7e6443 100644 --- a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt +++ b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt @@ -19,15 +19,22 @@ import org.mockito.ArgumentMatchers.any import org.mockito.ArgumentMatchers.argThat import org.mockito.ArgumentMatchers.eq import org.mockito.Mockito +import org.mockito.Mockito.atLeast +import org.mockito.Mockito.atMostOnce +import org.mockito.Mockito.doAnswer import org.mockito.Mockito.doNothing import org.mockito.Mockito.never +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.webrtc.DataChannel import org.webrtc.MediaConstraints import org.webrtc.PeerConnection import org.webrtc.PeerConnectionFactory import java.nio.ByteBuffer import java.util.HashMap +import kotlin.concurrent.thread +@Suppress("LongMethod", "TooGenericExceptionCaught") class PeerConnectionWrapperTest { private var peerConnectionWrapper: PeerConnectionWrapper? = null @@ -36,6 +43,23 @@ class PeerConnectionWrapperTest { private var mockedSignalingMessageReceiver: SignalingMessageReceiver? = null private var mockedSignalingMessageSender: SignalingMessageSender? = null + /** + * Helper answer for DataChannel methods. + */ + private class ReturnValueOrThrowIfDisposed(val value: T) : + Answer { + override fun answer(currentInvocation: InvocationOnMock): T { + if (Mockito.mockingDetails(currentInvocation.mock).invocations.find { + it!!.method.name === "dispose" + } !== null + ) { + throw IllegalStateException("DataChannel has been disposed") + } + + return value + } + } + /** * Helper matcher for DataChannelMessages. */ @@ -195,6 +219,83 @@ class PeerConnectionWrapperTest { ) } + @Test + fun testSendDataChannelMessageBeforeOpeningDataChannelWithDifferentThreads() { + // A brute force approach is used to test race conditions between different threads just repeating the test + // several times. Due to this the test passing could be a false positive, as it could have been a matter of + // luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with + // that number of reruns, it fails when it should). + for (i in 1..1000) { + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + any(PeerConnection.Observer::class.java) + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenReturn("status") + Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.CONNECTING) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + val statusDataChannelObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(DataChannel.Observer::class.java) + + doNothing().`when`(mockedStatusDataChannel) + .registerObserver(statusDataChannelObserverArgumentCaptor.capture()) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelMessageCount = 5 + + val sendThread = thread { + for (j in 1..dataChannelMessageCount) { + peerConnectionWrapper!!.send(DataChannelMessage("the-message-type-$j")) + } + } + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for ConcurrentModificationExceptions when iterating over the data channel messages). + var exceptionOnStateChange: Exception? = null + + val openDataChannelThread = thread { + Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.OPEN) + + try { + statusDataChannelObserverArgumentCaptor.value.onStateChange() + } catch (e: Exception) { + exceptionOnStateChange = e + } + } + + sendThread.join() + openDataChannelThread.join() + + if (exceptionOnStateChange !== null) { + throw exceptionOnStateChange!! + } + + for (j in 1..dataChannelMessageCount) { + Mockito.verify(mockedStatusDataChannel).send( + argThat(MatchesDataChannelMessage(DataChannelMessage("the-message-type-$j"))) + ) + } + } + } + @Test fun testReceiveDataChannelMessage() { Mockito.`when`( @@ -381,4 +482,279 @@ class PeerConnectionWrapperTest { Mockito.verify(mockedStatusDataChannel).dispose() Mockito.verify(mockedRandomIdDataChannel).dispose() } + + @Test + fun testRemovePeerConnectionWhileAddingRemoteDataChannelsWithDifferentThreads() { + // A brute force approach is used to test race conditions between different threads just repeating the test + // several times. Due to this the test passing could be a false positive, as it could have been a matter of + // luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with + // that number of reruns, it fails when it should). + for (i in 1..1000) { + val peerConnectionObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer( + ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN) + ) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelCount = 5 + + val mockedRandomIdDataChannels: MutableList = ArrayList() + val dataChannelObservers: MutableList = ArrayList() + for (j in 0.. + if (Mockito.mockingDetails(invocation.mock).invocations.find { + it!!.method.name === "dispose" + } !== null + ) { + throw IllegalStateException("DataChannel has been disposed") + } + + dataChannelObservers[j] = invocation.getArgument(0, DataChannel.Observer::class.java) + + null + }.`when`(mockedRandomIdDataChannels[j]).registerObserver(any()) + } + + val onDataChannelThread = thread { + // Add again "status" data channel to test that it is correctly disposed also in that case (which + // should not happen anyway even if it was added by the remote peer, but just in case) + peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedStatusDataChannel) + + for (j in 0.. = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()) + .thenAnswer(ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN)) + Mockito.`when`(mockedStatusDataChannel.send(any())).thenAnswer(ReturnValueOrThrowIfDisposed(true)) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val dataChannelMessageCount = 5 + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for IllegalStateExceptions when using a disposed data channel). + var exceptionSend: Exception? = null + + val sendThread = thread { + try { + for (j in 0.. = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status")) + Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer( + ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN) + ) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + val statusDataChannelObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(DataChannel.Observer::class.java) + + doNothing().`when`(mockedStatusDataChannel) + .registerObserver(statusDataChannelObserverArgumentCaptor.capture()) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val mockedDataChannelMessageListener = Mockito.mock(DataChannelMessageListener::class.java) + peerConnectionWrapper!!.addListener(mockedDataChannelMessageListener) + + // Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done + // (for example, for IllegalStateExceptions when using a disposed data channel). + var exceptionOnMessage: Exception? = null + + val onMessageThread = thread { + try { + // It is assumed that, even if its data channel was disposed, its buffers can be used while there + // is a reference to them, so no special mock behaviour is added to throw an exception in that case. + statusDataChannelObserverArgumentCaptor.value.onMessage( + dataChannelMessageToBuffer(DataChannelMessage("audioOn")) + ) + + statusDataChannelObserverArgumentCaptor.value.onMessage( + dataChannelMessageToBuffer(DataChannelMessage("audioOff")) + ) + } catch (e: Exception) { + exceptionOnMessage = e + } + } + + val removePeerConnectionThread = thread { + peerConnectionWrapper!!.removePeerConnection() + } + + onMessageThread.join() + removePeerConnectionThread.join() + + if (exceptionOnMessage !== null) { + throw exceptionOnMessage!! + } + + Mockito.verify(mockedStatusDataChannel).registerObserver(any()) + Mockito.verify(mockedStatusDataChannel).dispose() + Mockito.verify(mockedStatusDataChannel, atLeast(0)).label() + Mockito.verify(mockedStatusDataChannel, atLeast(0)).state() + Mockito.verifyNoMoreInteractions(mockedStatusDataChannel) + Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOn() + Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOff() + Mockito.verifyNoMoreInteractions(mockedDataChannelMessageListener) + } + } }