diff --git a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java index 6413f0344..27ad9e439 100644 --- a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java +++ b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -57,7 +58,7 @@ public class PeerConnectionWrapper { private PeerConnection peerConnection; private String sessionId; private final MediaConstraints mediaConstraints; - private DataChannel statusDataChannel; + private final Map dataChannels = new HashMap<>(); private final SdpObserver sdpObserver; private final boolean hasInitiated; @@ -144,8 +145,11 @@ public class PeerConnectionWrapper { if (hasMCU || hasInitiated) { DataChannel.Init init = new DataChannel.Init(); init.negotiated = false; - statusDataChannel = peerConnection.createDataChannel("status", init); + + DataChannel statusDataChannel = peerConnection.createDataChannel("status", init); statusDataChannel.registerObserver(new DataChannelObserver(statusDataChannel)); + dataChannels.put("status", statusDataChannel); + if (isMCUPublisher) { peerConnection.createOffer(sdpObserver, mediaConstraints); } else if (hasMCU && "video".equals(this.videoStreamType)) { @@ -233,13 +237,12 @@ public class PeerConnectionWrapper { public void removePeerConnection() { signalingMessageReceiver.removeListener(webRtcMessageListener); - if (statusDataChannel != null) { - statusDataChannel.dispose(); - statusDataChannel = null; - Log.d(TAG, "Disposed DataChannel"); - } else { - Log.d(TAG, "DataChannel is null."); + for (DataChannel dataChannel: dataChannels.values()) { + Log.d(TAG, "Disposed DataChannel " + dataChannel.label()); + + dataChannel.dispose(); } + dataChannels.clear(); if (peerConnection != null) { peerConnection.close(); @@ -283,6 +286,7 @@ public class PeerConnectionWrapper { */ public void send(DataChannelMessage dataChannelMessage) { ByteBuffer buffer; + DataChannel statusDataChannel = dataChannels.get("status"); if (statusDataChannel != null && dataChannelMessage != null) { try { buffer = ByteBuffer.wrap(LoganSquare.serialize(dataChannelMessage).getBytes()); @@ -525,7 +529,23 @@ public class PeerConnectionWrapper { @Override public void onDataChannel(DataChannel dataChannel) { + // Another data channel with the same label, no matter if the same instance or a different one, should not + // be added, but just in case. + DataChannel oldDataChannel = dataChannels.get(dataChannel.label()); + if (oldDataChannel == dataChannel) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again"); + + return; + } + + if (oldDataChannel != null) { + Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists"); + + oldDataChannel.dispose(); + } + dataChannel.registerObserver(new DataChannelObserver(dataChannel)); + dataChannels.put(dataChannel.label(), dataChannel); } @Override diff --git a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt index be628c2af..450cc4fa8 100644 --- a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt +++ b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt @@ -288,4 +288,47 @@ class PeerConnectionWrapperTest { Mockito.verify(mockedDataChannelMessageListener).onAudioOff() Mockito.verifyNoMoreInteractions(mockedDataChannelMessageListener) } + + @Test + fun testRemovePeerConnectionWithOpenRemoteDataChannel() { + val peerConnectionObserverArgumentCaptor: ArgumentCaptor = + ArgumentCaptor.forClass(PeerConnection.Observer::class.java) + + Mockito.`when`( + mockedPeerConnectionFactory!!.createPeerConnection( + any(PeerConnection.RTCConfiguration::class.java), + peerConnectionObserverArgumentCaptor.capture() + ) + ).thenReturn(mockedPeerConnection) + + val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedStatusDataChannel.label()).thenReturn("status") + Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.OPEN) + Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any())) + .thenReturn(mockedStatusDataChannel) + + peerConnectionWrapper = PeerConnectionWrapper( + mockedPeerConnectionFactory, + ArrayList(), + MediaConstraints(), + "the-session-id", + "the-local-session-id", + null, + true, + true, + "video", + mockedSignalingMessageReceiver, + mockedSignalingMessageSender + ) + + val mockedRandomIdDataChannel = Mockito.mock(DataChannel::class.java) + Mockito.`when`(mockedRandomIdDataChannel.label()).thenReturn("random-id") + Mockito.`when`(mockedRandomIdDataChannel.state()).thenReturn(DataChannel.State.OPEN) + peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedRandomIdDataChannel) + + peerConnectionWrapper!!.removePeerConnection() + + Mockito.verify(mockedStatusDataChannel).dispose() + Mockito.verify(mockedRandomIdDataChannel).dispose() + } }