From 014b18de8923e9ab398521d7ae1b0d0c5b4f74c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Calvi=C3=B1o=20S=C3=A1nchez?= <danxuliu@gmail.com>
Date: Fri, 6 Dec 2024 03:08:15 +0100
Subject: [PATCH] Fix remote data channels not disposed when removing peer
 connection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Daniel Calviño Sánchez <danxuliu@gmail.com>
---
 .../talk/webrtc/PeerConnectionWrapper.java    | 36 ++++++++++++----
 .../talk/webrtc/PeerConnectionWrapperTest.kt  | 43 +++++++++++++++++++
 2 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java
index 6413f0344..27ad9e439 100644
--- a/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java
+++ b/app/src/main/java/com/nextcloud/talk/webrtc/PeerConnectionWrapper.java
@@ -34,6 +34,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -57,7 +58,7 @@ public class PeerConnectionWrapper {
     private PeerConnection peerConnection;
     private String sessionId;
     private final MediaConstraints mediaConstraints;
-    private DataChannel statusDataChannel;
+    private final Map<String, DataChannel> dataChannels = new HashMap<>();
     private final SdpObserver sdpObserver;
 
     private final boolean hasInitiated;
@@ -144,8 +145,11 @@ public class PeerConnectionWrapper {
             if (hasMCU || hasInitiated) {
                 DataChannel.Init init = new DataChannel.Init();
                 init.negotiated = false;
-                statusDataChannel = peerConnection.createDataChannel("status", init);
+
+                DataChannel statusDataChannel = peerConnection.createDataChannel("status", init);
                 statusDataChannel.registerObserver(new DataChannelObserver(statusDataChannel));
+                dataChannels.put("status", statusDataChannel);
+
                 if (isMCUPublisher) {
                     peerConnection.createOffer(sdpObserver, mediaConstraints);
                 } else if (hasMCU && "video".equals(this.videoStreamType)) {
@@ -233,13 +237,12 @@ public class PeerConnectionWrapper {
     public void removePeerConnection() {
         signalingMessageReceiver.removeListener(webRtcMessageListener);
 
-        if (statusDataChannel != null) {
-            statusDataChannel.dispose();
-            statusDataChannel = null;
-            Log.d(TAG, "Disposed DataChannel");
-        } else {
-            Log.d(TAG, "DataChannel is null.");
+        for (DataChannel dataChannel: dataChannels.values()) {
+            Log.d(TAG, "Disposed DataChannel " + dataChannel.label());
+
+            dataChannel.dispose();
         }
+        dataChannels.clear();
 
         if (peerConnection != null) {
             peerConnection.close();
@@ -283,6 +286,7 @@ public class PeerConnectionWrapper {
      */
     public void send(DataChannelMessage dataChannelMessage) {
         ByteBuffer buffer;
+        DataChannel statusDataChannel = dataChannels.get("status");
         if (statusDataChannel != null && dataChannelMessage != null) {
             try {
                 buffer = ByteBuffer.wrap(LoganSquare.serialize(dataChannelMessage).getBytes());
@@ -525,7 +529,23 @@ public class PeerConnectionWrapper {
 
         @Override
         public void onDataChannel(DataChannel dataChannel) {
+            // Another data channel with the same label, no matter if the same instance or a different one, should not
+            // be added, but just in case.
+            DataChannel oldDataChannel = dataChannels.get(dataChannel.label());
+            if (oldDataChannel == dataChannel) {
+                Log.w(TAG, "Data channel with label " + dataChannel.label() + " added again");
+
+                return;
+            }
+
+            if (oldDataChannel != null) {
+                Log.w(TAG, "Data channel with label " + dataChannel.label() + " exists");
+
+                oldDataChannel.dispose();
+            }
+
             dataChannel.registerObserver(new DataChannelObserver(dataChannel));
+            dataChannels.put(dataChannel.label(), dataChannel);
         }
 
         @Override
diff --git a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt
index be628c2af..450cc4fa8 100644
--- a/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt
+++ b/app/src/test/java/com/nextcloud/talk/webrtc/PeerConnectionWrapperTest.kt
@@ -288,4 +288,47 @@ class PeerConnectionWrapperTest {
         Mockito.verify(mockedDataChannelMessageListener).onAudioOff()
         Mockito.verifyNoMoreInteractions(mockedDataChannelMessageListener)
     }
+
+    @Test
+    fun testRemovePeerConnectionWithOpenRemoteDataChannel() {
+        val peerConnectionObserverArgumentCaptor: ArgumentCaptor<PeerConnection.Observer> =
+            ArgumentCaptor.forClass(PeerConnection.Observer::class.java)
+
+        Mockito.`when`(
+            mockedPeerConnectionFactory!!.createPeerConnection(
+                any(PeerConnection.RTCConfiguration::class.java),
+                peerConnectionObserverArgumentCaptor.capture()
+            )
+        ).thenReturn(mockedPeerConnection)
+
+        val mockedStatusDataChannel = Mockito.mock(DataChannel::class.java)
+        Mockito.`when`(mockedStatusDataChannel.label()).thenReturn("status")
+        Mockito.`when`(mockedStatusDataChannel.state()).thenReturn(DataChannel.State.OPEN)
+        Mockito.`when`(mockedPeerConnection!!.createDataChannel(eq("status"), any()))
+            .thenReturn(mockedStatusDataChannel)
+
+        peerConnectionWrapper = PeerConnectionWrapper(
+            mockedPeerConnectionFactory,
+            ArrayList<PeerConnection.IceServer>(),
+            MediaConstraints(),
+            "the-session-id",
+            "the-local-session-id",
+            null,
+            true,
+            true,
+            "video",
+            mockedSignalingMessageReceiver,
+            mockedSignalingMessageSender
+        )
+
+        val mockedRandomIdDataChannel = Mockito.mock(DataChannel::class.java)
+        Mockito.`when`(mockedRandomIdDataChannel.label()).thenReturn("random-id")
+        Mockito.`when`(mockedRandomIdDataChannel.state()).thenReturn(DataChannel.State.OPEN)
+        peerConnectionObserverArgumentCaptor.value.onDataChannel(mockedRandomIdDataChannel)
+
+        peerConnectionWrapper!!.removePeerConnection()
+
+        Mockito.verify(mockedStatusDataChannel).dispose()
+        Mockito.verify(mockedRandomIdDataChannel).dispose()
+    }
 }