Fix "removePeerConnection" not being thread-safe

Adding and disposing remote data channels is done from different
threads; they are added from the WebRTC signaling thread when
"onDataChannel" is called, while they can be disposed potentially from
any thread when "removePeerConnection" is called. To prevent race
conditions between them now both operations are synchronized.

However, as "onDataChannel" belongs to an inner class it needs to use a
synchronized statement with the outer class lock. This could still cause
a race condition if the same data channel was added again; this should
not happen, but it is handled just in case.

Moreover, once a data channel is disposed it can be no longer used, and
trying to call any of its methods throws an "IllegalStateException". Due
to this, as sending can be also done potentially from any thread, it
needs to be synchronized too with removing the peer connection.

State changes on data channels as well as receiving messages are also
done in the WebRTC signaling thread. State changes needs synchronization
as well, although receiving messages should not, as it does not directly
use the data channel (and it is assumed that using the buffers of a
disposed data channel is safe). Nevertheless a little check (which in
this case requires synchronization) was added to ignore the received
messages if the peer connection was removed already.

Finally, the synchronization added to "send" and "onStateChange" had the
nice side effect of making the pending data channel messages thread-safe
too, as before it could happen that a message was enqueued when the
pending messages were being sent, which caused a
"ConcurrentModificationException".

Signed-off-by: Daniel Calviño Sánchez <danxuliu@gmail.com>
This commit is contained in:
Daniel Calviño Sánchez 2024-12-08 05:23:11 +01:00 committed by Marcel Hibbe
parent a301bdeb76
commit b6d6986b62
No known key found for this signature in database
GPG Key ID: C793F8B59F43CE7B
2 changed files with 440 additions and 24 deletions

View File

@ -235,7 +235,7 @@ public class PeerConnectionWrapper {
return stream;
}
public void removePeerConnection() {
public synchronized void removePeerConnection() {
signalingMessageReceiver.removeListener(webRtcMessageListener);
for (DataChannel dataChannel: dataChannels.values()) {
@ -292,7 +292,7 @@ public class PeerConnectionWrapper {
*
* @param dataChannelMessage the message to send
*/
public void send(DataChannelMessage dataChannelMessage) {
public synchronized void send(DataChannelMessage dataChannelMessage) {
if (dataChannelMessage == null) {
return;
}
@ -414,20 +414,38 @@ public class PeerConnectionWrapper {
@Override
public void onStateChange() {
if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) {
for (DataChannelMessage dataChannelMessage: pendingDataChannelMessages) {
send(dataChannelMessage);
synchronized (PeerConnectionWrapper.this) {
// The PeerConnection could have been removed in parallel even with the synchronization (as just after
// "onStateChange" was called "removePeerConnection" could have acquired the lock).
if (peerConnection == null) {
return;
}
pendingDataChannelMessages.clear();
}
if (dataChannel.state() == DataChannel.State.OPEN) {
sendInitialMediaStatus();
if (dataChannel.state() == DataChannel.State.OPEN && "status".equals(dataChannelLabel)) {
for (DataChannelMessage dataChannelMessage : pendingDataChannelMessages) {
send(dataChannelMessage);
}
pendingDataChannelMessages.clear();
}
if (dataChannel.state() == DataChannel.State.OPEN) {
sendInitialMediaStatus();
}
}
}
@Override
public void onMessage(DataChannel.Buffer buffer) {
synchronized (PeerConnectionWrapper.this) {
// It is assumed that, even if its data channel was disposed, its buffers can be used while there is
// a reference to them, so it would not be necessary to check this from a thread-safety point of view.
// Nevertheless, if the remote peer connection was removed it would not make sense to notify the
// listeners anyway.
if (peerConnection == null) {
return;
}
}
if (buffer.binary) {
Log.d(TAG, "Received binary data channel message over " + dataChannelLabel + " " + sessionId);
return;
@ -557,23 +575,45 @@ public class PeerConnectionWrapper {
@Override
public void onDataChannel(DataChannel dataChannel) {
// Another data channel with the same label, no matter if the same instance or a different one, should not
// be added, but just in case.
DataChannel oldDataChannel = dataChannels.get(dataChannel.label());
if (oldDataChannel == dataChannel) {
Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again");
synchronized (PeerConnectionWrapper.this) {
// Another data channel with the same label, no matter if the same instance or a different one, should
// not be added, but this is handled just in case.
// Moreover, if it were possible that an already added data channel was added again there would be a
// potential race condition with "removePeerConnection", even with the synchronization, as it would
// be possible that "onDataChannel" was called, then "removePeerConnection" disposed the data
// channel, and then "onDataChannel" continued in the synchronized statements and tried to get the
// label, which would throw an exception due to the data channel having been disposed already.
String dataChannelLabel;
try {
dataChannelLabel = dataChannel.label();
} catch (IllegalStateException e) {
// The data channel was disposed already, nothing to do.
return;
}
return;
DataChannel oldDataChannel = dataChannels.get(dataChannelLabel);
if (oldDataChannel == dataChannel) {
Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again");
return;
}
if (oldDataChannel != null) {
Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists");
oldDataChannel.dispose();
}
// If the peer connection was removed in parallel dispose the data channel instead of adding it.
if (peerConnection == null) {
dataChannel.dispose();
return;
}
dataChannel.registerObserver(new DataChannelObserver(dataChannel));
dataChannels.put(dataChannel.label(), dataChannel);
}
if (oldDataChannel != null) {
Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists");
oldDataChannel.dispose();
}
dataChannel.registerObserver(new DataChannelObserver(dataChannel));
dataChannels.put(dataChannel.label(), dataChannel);
}
@Override

View File

@ -19,15 +19,22 @@ import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.argThat
import org.mockito.ArgumentMatchers.eq
import org.mockito.Mockito
import org.mockito.Mockito.atLeast
import org.mockito.Mockito.atMostOnce
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doNothing
import org.mockito.Mockito.never
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.webrtc.DataChannel
import org.webrtc.MediaConstraints
import org.webrtc.PeerConnection
import org.webrtc.PeerConnectionFactory
import java.nio.ByteBuffer
import java.util.HashMap
import kotlin.concurrent.thread
@Suppress("LongMethod", "TooGenericExceptionCaught")
class PeerConnectionWrapperTest {
private var peerConnectionWrapper: PeerConnectionWrapper? = null
@ -36,6 +43,23 @@ class PeerConnectionWrapperTest {
private var mockedSignalingMessageReceiver: SignalingMessageReceiver? = null
private var mockedSignalingMessageSender: SignalingMessageSender? = null
/**
* Helper answer for DataChannel methods.
*/
private class ReturnValueOrThrowIfDisposed<T>(val value: T) :
Answer<T> {
override fun answer(currentInvocation: InvocationOnMock): T {
if (Mockito.mockingDetails(currentInvocation.mock).invocations.find {
it!!.method.name === "dispose"
} !== null
) {
throw IllegalStateException("DataChannel has been disposed")
}
return value
}
}
/**
* Helper matcher for DataChannelMessages.
*/
@ -195,6 +219,83 @@ class PeerConnectionWrapperTest {
)
}
@Test
fun testSendDataChannelMessageBeforeOpeningDataChannelWithDifferentThreads() {
// A brute force approach is used to test race conditions between different threads just repeating the test
// several times. Due to this the test passing could be a false positive, as it could have been a matter of
// luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with
// that number of reruns, it fails when it should).
for (i in 1..1000) {
Mockito.`when`(
mockedPeerConnectionFactory!!.createPeerConnection(
any(PeerConnection.RTCConfiguration::class.java),
any(PeerConnection.Observer::class.java)
)
).thenReturn(mockedPeerConnection)
val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java)
Mockito.`when`(mockedStatusDataChannel.label()).thenReturn("status")
Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.CONNECTING)
Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any()))
.thenReturn(mockedStatusDataChannel)
val statusDataChannelObserverArgumentCaptor: ArgumentCaptor<DataChannel.Observer> =
ArgumentCaptor.forClass(DataChannel.Observer::class.java)
doNothing().`when`(mockedStatusDataChannel)
.registerObserver(statusDataChannelObserverArgumentCaptor.capture())
peerConnectionWrapper = PeerConnectionWrapper(
mockedPeerConnectionFactory,
ArrayList<PeerConnection.IceServer>(),
MediaConstraints(),
"the-session-id",
"the-local-session-id",
null,
true,
true,
"video",
mockedSignalingMessageReceiver,
mockedSignalingMessageSender
)
val dataChannelMessageCount = 5
val sendThread = thread {
for (j in 1..dataChannelMessageCount) {
peerConnectionWrapper!!.send(DataChannelMessage("the-message-type-$j"))
}
}
// Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done
// (for example, for ConcurrentModificationExceptions when iterating over the data channel messages).
var exceptionOnStateChange: Exception? = null
val openDataChannelThread = thread {
Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.OPEN)
try {
statusDataChannelObserverArgumentCaptor.value.onStateChange()
} catch (e: Exception) {
exceptionOnStateChange = e
}
}
sendThread.join()
openDataChannelThread.join()
if (exceptionOnStateChange !== null) {
throw exceptionOnStateChange!!
}
for (j in 1..dataChannelMessageCount) {
Mockito.verify(mockedStatusDataChannel).send(
argThat(MatchesDataChannelMessage(DataChannelMessage("the-message-type-$j")))
)
}
}
}
@Test
fun testReceiveDataChannelMessage() {
Mockito.`when`(
@ -381,4 +482,279 @@ class PeerConnectionWrapperTest {
Mockito.verify(mockedStatusDataChannel).dispose()
Mockito.verify(mockedRandomIdDataChannel).dispose()
}
@Test
fun testRemovePeerConnectionWhileAddingRemoteDataChannelsWithDifferentThreads() {
// A brute force approach is used to test race conditions between different threads just repeating the test
// several times. Due to this the test passing could be a false positive, as it could have been a matter of
// luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with
// that number of reruns, it fails when it should).
for (i in 1..1000) {
val peerConnectionObserverArgumentCaptor: ArgumentCaptor<PeerConnection.Observer> =
ArgumentCaptor.forClass(PeerConnection.Observer::class.java)
Mockito.`when`(
mockedPeerConnectionFactory!!.createPeerConnection(
any(PeerConnection.RTCConfiguration::class.java),
peerConnectionObserverArgumentCaptor.capture()
)
).thenReturn(mockedPeerConnection)
val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java)
Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status"))
Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer(
ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN)
)
Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any()))
.thenReturn(mockedStatusDataChannel)
peerConnectionWrapper = PeerConnectionWrapper(
mockedPeerConnectionFactory,
ArrayList<PeerConnection.IceServer>(),
MediaConstraints(),
"the-session-id",
"the-local-session-id",
null,
true,
true,
"video",
mockedSignalingMessageReceiver,
mockedSignalingMessageSender
)
val dataChannelCount = 5
val mockedRandomIdDataChannels: MutableList<DataChannel> = ArrayList()
val dataChannelObservers: MutableList<DataChannel.Observer?> = ArrayList()
for (j in 0..<dataChannelCount) {
mockedRandomIdDataChannels.add(Mockito.mock(DataChannel::class.java))
// Add data channels with duplicated labels (from the second data channel and onwards) to test that
// they are correctly disposed also in that case (which should not happen anyway, but just in case).
Mockito.`when`(mockedRandomIdDataChannels[j].label())
.thenAnswer(ReturnValueOrThrowIfDisposed("random-id-" + ((j + 1) / 2)))
Mockito.`when`(mockedRandomIdDataChannels[j].state())
.thenAnswer(ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN))
// Store a reference to the registered observer, if any, to be called after the registration. The call
// is done outside the mock to better simulate the normal behaviour, as it would not be called during
// the registration itself.
dataChannelObservers.add(null)
doAnswer { invocation ->
if (Mockito.mockingDetails(invocation.mock).invocations.find {
it!!.method.name === "dispose"
} !== null
) {
throw IllegalStateException("DataChannel has been disposed")
}
dataChannelObservers[j] = invocation.getArgument(0, DataChannel.Observer::class.java)
null
}.`when`(mockedRandomIdDataChannels[j]).registerObserver(any())
}
val onDataChannelThread = thread {
// Add again "status" data channel to test that it is correctly disposed also in that case (which
// should not happen anyway even if it was added by the remote peer, but just in case)
peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedStatusDataChannel)
for (j in 0..<dataChannelCount) {
peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedRandomIdDataChannels[j])
// Call "onStateChange" on the registered observer to simulate that the data channel was opened.
dataChannelObservers[j]?.onStateChange()
}
}
// Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done
// (for example, for ConcurrentModificationExceptions when iterating over the data channels).
var exceptionRemovePeerConnection: Exception? = null
val removePeerConnectionThread = thread {
try {
peerConnectionWrapper!!.removePeerConnection()
} catch (e: Exception) {
exceptionRemovePeerConnection = e
}
}
onDataChannelThread.join()
removePeerConnectionThread.join()
if (exceptionRemovePeerConnection !== null) {
throw exceptionRemovePeerConnection!!
}
Mockito.verify(mockedStatusDataChannel).dispose()
for (j in 0..<dataChannelCount) {
Mockito.verify(mockedRandomIdDataChannels[j]).dispose()
}
}
}
@Test
fun testRemovePeerConnectionWhileSendingWithDifferentThreads() {
// A brute force approach is used to test race conditions between different threads just repeating the test
// several times. Due to this the test passing could be a false positive, as it could have been a matter of
// luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with
// that number of reruns, it fails when it should).
for (i in 1..1000) {
val peerConnectionObserverArgumentCaptor: ArgumentCaptor<PeerConnection.Observer> =
ArgumentCaptor.forClass(PeerConnection.Observer::class.java)
Mockito.`when`(
mockedPeerConnectionFactory!!.createPeerConnection(
any(PeerConnection.RTCConfiguration::class.java),
peerConnectionObserverArgumentCaptor.capture()
)
).thenReturn(mockedPeerConnection)
val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java)
Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status"))
Mockito.`when`(mockedStatusDataChannel.state())
.thenAnswer(ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN))
Mockito.`when`(mockedStatusDataChannel.send(any())).thenAnswer(ReturnValueOrThrowIfDisposed(true))
Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any()))
.thenReturn(mockedStatusDataChannel)
peerConnectionWrapper = PeerConnectionWrapper(
mockedPeerConnectionFactory,
ArrayList<PeerConnection.IceServer>(),
MediaConstraints(),
"the-session-id",
"the-local-session-id",
null,
true,
true,
"video",
mockedSignalingMessageReceiver,
mockedSignalingMessageSender
)
val dataChannelMessageCount = 5
// Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done
// (for example, for IllegalStateExceptions when using a disposed data channel).
var exceptionSend: Exception? = null
val sendThread = thread {
try {
for (j in 0..<dataChannelMessageCount) {
peerConnectionWrapper!!.send(DataChannelMessage("the-message-type-$j"))
}
} catch (e: Exception) {
exceptionSend = e
}
}
val removePeerConnectionThread = thread {
peerConnectionWrapper!!.removePeerConnection()
}
sendThread.join()
removePeerConnectionThread.join()
if (exceptionSend !== null) {
throw exceptionSend!!
}
Mockito.verify(mockedStatusDataChannel).registerObserver(any())
Mockito.verify(mockedStatusDataChannel).dispose()
Mockito.verify(mockedStatusDataChannel, atLeast(0)).label()
Mockito.verify(mockedStatusDataChannel, atLeast(0)).state()
Mockito.verify(mockedStatusDataChannel, atLeast(0)).send(any())
Mockito.verifyNoMoreInteractions(mockedStatusDataChannel)
}
}
@Test
fun testRemovePeerConnectionWhileReceivingWithDifferentThreads() {
// A brute force approach is used to test race conditions between different threads just repeating the test
// several times. Due to this the test passing could be a false positive, as it could have been a matter of
// luck, but even if the test may wrongly pass sometimes it is better than nothing (although, in general, with
// that number of reruns, it fails when it should).
for (i in 1..1000) {
val peerConnectionObserverArgumentCaptor: ArgumentCaptor<PeerConnection.Observer> =
ArgumentCaptor.forClass(PeerConnection.Observer::class.java)
Mockito.`when`(
mockedPeerConnectionFactory!!.createPeerConnection(
any(PeerConnection.RTCConfiguration::class.java),
peerConnectionObserverArgumentCaptor.capture()
)
).thenReturn(mockedPeerConnection)
val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java)
Mockito.`when`(mockedStatusDataChannel.label()).thenAnswer(ReturnValueOrThrowIfDisposed("status"))
Mockito.`when`(mockedStatusDataChannel.state()).thenAnswer(
ReturnValueOrThrowIfDisposed(DataChannel.State.OPEN)
)
Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any()))
.thenReturn(mockedStatusDataChannel)
val statusDataChannelObserverArgumentCaptor: ArgumentCaptor<DataChannel.Observer> =
ArgumentCaptor.forClass(DataChannel.Observer::class.java)
doNothing().`when`(mockedStatusDataChannel)
.registerObserver(statusDataChannelObserverArgumentCaptor.capture())
peerConnectionWrapper = PeerConnectionWrapper(
mockedPeerConnectionFactory,
ArrayList<PeerConnection.IceServer>(),
MediaConstraints(),
"the-session-id",
"the-local-session-id",
null,
true,
true,
"video",
mockedSignalingMessageReceiver,
mockedSignalingMessageSender
)
val mockedDataChannelMessageListener = Mockito.mock(DataChannelMessageListener::class.java)
peerConnectionWrapper!!.addListener(mockedDataChannelMessageListener)
// Exceptions thrown in threads are not propagated to the main thread, so it needs to be explicitly done
// (for example, for IllegalStateExceptions when using a disposed data channel).
var exceptionOnMessage: Exception? = null
val onMessageThread = thread {
try {
// It is assumed that, even if its data channel was disposed, its buffers can be used while there
// is a reference to them, so no special mock behaviour is added to throw an exception in that case.
statusDataChannelObserverArgumentCaptor.value.onMessage(
dataChannelMessageToBuffer(DataChannelMessage("audioOn"))
)
statusDataChannelObserverArgumentCaptor.value.onMessage(
dataChannelMessageToBuffer(DataChannelMessage("audioOff"))
)
} catch (e: Exception) {
exceptionOnMessage = e
}
}
val removePeerConnectionThread = thread {
peerConnectionWrapper!!.removePeerConnection()
}
onMessageThread.join()
removePeerConnectionThread.join()
if (exceptionOnMessage !== null) {
throw exceptionOnMessage!!
}
Mockito.verify(mockedStatusDataChannel).registerObserver(any())
Mockito.verify(mockedStatusDataChannel).dispose()
Mockito.verify(mockedStatusDataChannel, atLeast(0)).label()
Mockito.verify(mockedStatusDataChannel, atLeast(0)).state()
Mockito.verifyNoMoreInteractions(mockedStatusDataChannel)
Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOn()
Mockito.verify(mockedDataChannelMessageListener, atMostOnce()).onAudioOff()
Mockito.verifyNoMoreInteractions(mockedDataChannelMessageListener)
}
}
}