Fix race when starting NetworkMonitor

NetworkMonitor used to obtain LinkProperties and NetworkCapabilities via
synchronous calls to ConnectivityManager after receiving an asynchronous
notification. This is prone to races: the network could be gone before
the LinkProperties/NetworkCapabilities can be fetched, in which case
NetworkMonitor kept its previous (initially empty) values.

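Fix the race by passing LinkProperties/NetworkCapabilities directly to
NetworkMonitor as arguments of the asynchronous notifications
(notifyNetworkConnected, notifyLinkPropertiesChanged and
notifyNetworkCapabilitiesChanged).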

Test: atest FrameworksNetTests NetworkStackTests
Test: booted, WiFi works
Bug: 129375892
Change-Id: I200ac7ca6ff79590b11c9be705f650c92fd3cb63
diff --git a/src/com/android/server/NetworkStackService.java b/src/com/android/server/NetworkStackService.java
index 19e9108..63f057c 100644
--- a/src/com/android/server/NetworkStackService.java
+++ b/src/com/android/server/NetworkStackService.java
@@ -35,7 +35,9 @@
 import android.net.INetworkMonitor;
 import android.net.INetworkMonitorCallbacks;
 import android.net.INetworkStackConnector;
+import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkCapabilities;
 import android.net.PrivateDnsConfigParcel;
 import android.net.dhcp.DhcpServer;
 import android.net.dhcp.DhcpServingParams;
@@ -307,9 +309,9 @@
         }
 
         @Override
-        public void notifyNetworkConnected() {
+        public void notifyNetworkConnected(LinkProperties lp, NetworkCapabilities nc) {
             checkNetworkStackCallingPermission();
-            mNm.notifyNetworkConnected();
+            mNm.notifyNetworkConnected(lp, nc);
         }
 
         @Override
@@ -319,15 +321,15 @@
         }
 
         @Override
-        public void notifyLinkPropertiesChanged() {
+        public void notifyLinkPropertiesChanged(LinkProperties lp) {
             checkNetworkStackCallingPermission();
-            mNm.notifyLinkPropertiesChanged();
+            mNm.notifyLinkPropertiesChanged(lp);
         }
 
         @Override
-        public void notifyNetworkCapabilitiesChanged() {
+        public void notifyNetworkCapabilitiesChanged(NetworkCapabilities nc) {
             checkNetworkStackCallingPermission();
-            mNm.notifyNetworkCapabilitiesChanged();
+            mNm.notifyNetworkCapabilitiesChanged(nc);
         }
     }
 }
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java
index f3476ed..c000fc6 100644
--- a/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -78,6 +78,7 @@
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.Log;
+import android.util.Pair;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.RingBufferIndices;
@@ -221,19 +222,31 @@
      * Message to self indicating captive portal detection is completed.
      * obj = CaptivePortalProbeResult for detection result;
      */
-    public static final int CMD_PROBE_COMPLETE = 16;
+    private static final int CMD_PROBE_COMPLETE = 16;
 
     /**
      * ConnectivityService notifies NetworkMonitor of DNS query responses event.
      * arg1 = returncode in OnDnsEvent which indicates the response code for the DNS query.
      */
-    public static final int EVENT_DNS_NOTIFICATION = 17;
+    private static final int EVENT_DNS_NOTIFICATION = 17;
 
     /**
      * ConnectivityService notifies NetworkMonitor that the user accepts partial connectivity and
      * NetworkMonitor should ignore the https probe.
      */
-    public static final int EVENT_ACCEPT_PARTIAL_CONNECTIVITY = 18;
+    private static final int EVENT_ACCEPT_PARTIAL_CONNECTIVITY = 18;
+
+    /**
+     * ConnectivityService notifies NetworkMonitor of changed LinkProperties.
+     * obj = new LinkProperties.
+     */
+    private static final int EVENT_LINK_PROPERTIES_CHANGED = 19;
+
+    /**
+     * ConnectivityService notifies NetworkMonitor of changed NetworkCapabilities.
+     * obj = new NetworkCapabilities.
+     */
+    private static final int EVENT_NETWORK_CAPABILITIES_CHANGED = 20;
 
     // Start mReevaluateDelayMs at this value and double.
     private static final int INITIAL_REEVALUATE_DELAY_MS = 1000;
@@ -379,10 +392,8 @@
         mDataStallValidDnsTimeThreshold = getDataStallValidDnsTimeThreshold();
         mDataStallEvaluationType = getDataStallEvalutionType();
 
-        // mLinkProperties and mNetworkCapbilities must never be null or we will NPE.
-        // Provide empty objects in case we are started and the network disconnects before
-        // we can ever fetch them.
-        // TODO: Delete ASAP.
+        // Provide empty LinkProperties and NetworkCapabilities to make sure they are never null,
+        // even before notifyNetworkConnected.
         mLinkProperties = new LinkProperties();
         mNetworkCapabilities = new NetworkCapabilities(null);
     }
@@ -434,8 +445,16 @@
     /**
      * Send a notification to NetworkMonitor indicating that the network is now connected.
      */
-    public void notifyNetworkConnected() {
-        sendMessage(CMD_NETWORK_CONNECTED);
+    public void notifyNetworkConnected(LinkProperties lp, NetworkCapabilities nc) {
+        sendMessage(CMD_NETWORK_CONNECTED, new Pair<>(
+                new LinkProperties(lp), new NetworkCapabilities(nc)));
+    }
+
+    private void updateConnectedNetworkAttributes(Message connectedMsg) {
+        final Pair<LinkProperties, NetworkCapabilities> attrs =
+                (Pair<LinkProperties, NetworkCapabilities>) connectedMsg.obj;
+        mLinkProperties = attrs.first;
+        mNetworkCapabilities = attrs.second;
     }
 
     /**
@@ -448,37 +467,15 @@
     /**
      * Send a notification to NetworkMonitor indicating that link properties have changed.
      */
-    public void notifyLinkPropertiesChanged() {
-        getHandler().post(() -> {
-            updateLinkProperties();
-        });
-    }
-
-    private void updateLinkProperties() {
-        final LinkProperties lp = mCm.getLinkProperties(mNetwork);
-        // If null, we should soon get a message that the network was disconnected, and will stop.
-        if (lp != null) {
-            // TODO: send LinkProperties parceled in notifyLinkPropertiesChanged() and start().
-            mLinkProperties = lp;
-        }
+    public void notifyLinkPropertiesChanged(final LinkProperties lp) {
+        sendMessage(EVENT_LINK_PROPERTIES_CHANGED, new LinkProperties(lp));
     }
 
     /**
      * Send a notification to NetworkMonitor indicating that network capabilities have changed.
      */
-    public void notifyNetworkCapabilitiesChanged() {
-        getHandler().post(() -> {
-            updateNetworkCapabilities();
-        });
-    }
-
-    private void updateNetworkCapabilities() {
-        final NetworkCapabilities nc = mCm.getNetworkCapabilities(mNetwork);
-        // If null, we should soon get a message that the network was disconnected, and will stop.
-        if (nc != null) {
-            // TODO: send NetworkCapabilities parceled in notifyNetworkCapsChanged() and start().
-            mNetworkCapabilities = nc;
-        }
+    public void notifyNetworkCapabilitiesChanged(final NetworkCapabilities nc) {
+        sendMessage(EVENT_NETWORK_CAPABILITIES_CHANGED, new NetworkCapabilities(nc));
     }
 
     /**
@@ -547,16 +544,10 @@
     // does not entail any real state (hence no enter() or exit() routines).
     private class DefaultState extends State {
         @Override
-        public void enter() {
-            // TODO: have those passed parceled in start() and remove this
-            updateLinkProperties();
-            updateNetworkCapabilities();
-        }
-
-        @Override
         public boolean processMessage(Message message) {
             switch (message.what) {
                 case CMD_NETWORK_CONNECTED:
+                    updateConnectedNetworkAttributes(message);
                     logNetworkEvent(NetworkEvent.NETWORK_CONNECTED);
                     transitionTo(mEvaluatingState);
                     return HANDLED;
@@ -660,6 +651,12 @@
                 case EVENT_ACCEPT_PARTIAL_CONNECTIVITY:
                     mAcceptPartialConnectivity = true;
                     break;
+                case EVENT_LINK_PROPERTIES_CHANGED:
+                    mLinkProperties = (LinkProperties) message.obj;
+                    break;
+                case EVENT_NETWORK_CAPABILITIES_CHANGED:
+                    mNetworkCapabilities = (NetworkCapabilities) message.obj;
+                    break;
                 default:
                     break;
             }
@@ -684,6 +681,7 @@
         public boolean processMessage(Message message) {
             switch (message.what) {
                 case CMD_NETWORK_CONNECTED:
+                    updateConnectedNetworkAttributes(message);
                     transitionTo(mValidatedState);
                     break;
                 case CMD_EVALUATE_PRIVATE_DNS:
diff --git a/tests/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
index d732c4e..6665aae 100644
--- a/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -26,6 +26,8 @@
 import static junit.framework.Assert.assertEquals;
 import static junit.framework.Assert.assertFalse;
 
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -60,6 +62,7 @@
 import android.os.Bundle;
 import android.os.ConditionVariable;
 import android.os.Handler;
+import android.os.RemoteException;
 import android.os.SystemClock;
 import android.provider.Settings;
 import android.telephony.CellSignalStrength;
@@ -73,6 +76,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.Spy;
@@ -107,6 +111,7 @@
     private @Spy Network mNetwork = new Network(TEST_NETID);
     private @Mock DataStallStatsUtils mDataStallStatsUtils;
     private @Mock WifiInfo mWifiInfo;
+    private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor;
 
     private static final int TEST_NETID = 4242;
 
@@ -122,7 +127,7 @@
 
     private static final int HANDLER_TIMEOUT_MS = 1000;
 
-    private static final LinkProperties TEST_LINKPROPERTIES = new LinkProperties();
+    private static final LinkProperties TEST_LINK_PROPERTIES = new LinkProperties();
 
     private static final NetworkCapabilities METERED_CAPABILITIES = new NetworkCapabilities()
             .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
@@ -182,10 +187,6 @@
                 InetAddresses.parseNumericAddress("192.168.0.0")
         }).when(mNetwork).getAllByName(any());
 
-        // Default values. Individual tests can override these.
-        when(mCm.getLinkProperties(any())).thenReturn(TEST_LINKPROPERTIES);
-        when(mCm.getNetworkCapabilities(any())).thenReturn(METERED_CAPABILITIES);
-
         setMinDataStallEvaluateInterval(500);
         setDataStallEvaluationType(DATA_STALL_EVALUATION_TYPE_DNS);
         setValidDataStallDnsTimeThreshold(500);
@@ -195,10 +196,9 @@
     private class WrappedNetworkMonitor extends NetworkMonitor {
         private long mProbeTime = 0;
 
-        WrappedNetworkMonitor(Context context, Network network, IpConnectivityLog logger,
-                Dependencies deps, DataStallStatsUtils statsUtils) {
-                super(context, mCallbacks, network, logger,
-                        new SharedLog("test_nm"), deps, statsUtils);
+        WrappedNetworkMonitor() {
+                super(mContext, mCallbacks, mNetwork, mLogger, mValidationLogger, mDependencies,
+                        mDataStallStatsUtils);
         }
 
         @Override
@@ -216,33 +216,30 @@
         }
     }
 
-    private WrappedNetworkMonitor makeMeteredWrappedNetworkMonitor() {
-        final WrappedNetworkMonitor nm = new WrappedNetworkMonitor(
-                mContext, mNetwork, mLogger, mDependencies, mDataStallStatsUtils);
-        when(mCm.getNetworkCapabilities(any())).thenReturn(METERED_CAPABILITIES);
+    private WrappedNetworkMonitor makeMonitor() {
+        final WrappedNetworkMonitor nm = new WrappedNetworkMonitor();
         nm.start();
         waitForIdle(nm.getHandler());
         return nm;
     }
 
-    private WrappedNetworkMonitor makeNotMeteredWrappedNetworkMonitor() {
-        final WrappedNetworkMonitor nm = new WrappedNetworkMonitor(
-                mContext, mNetwork, mLogger, mDependencies, mDataStallStatsUtils);
-        when(mCm.getNetworkCapabilities(any())).thenReturn(NOT_METERED_CAPABILITIES);
-        nm.start();
-        waitForIdle(nm.getHandler());
+    private WrappedNetworkMonitor makeMeteredNetworkMonitor() {
+        final WrappedNetworkMonitor nm = makeMonitor();
+        setNetworkCapabilities(nm, METERED_CAPABILITIES);
         return nm;
     }
 
-    private NetworkMonitor makeMonitor() {
-        final NetworkMonitor nm = new NetworkMonitor(
-                mContext, mCallbacks, mNetwork, mLogger, mValidationLogger,
-                mDependencies, mDataStallStatsUtils);
-        nm.start();
-        waitForIdle(nm.getHandler());
+    private WrappedNetworkMonitor makeNotMeteredNetworkMonitor() {
+        final WrappedNetworkMonitor nm = makeMonitor();
+        setNetworkCapabilities(nm, NOT_METERED_CAPABILITIES);
         return nm;
     }
 
+    private void setNetworkCapabilities(NetworkMonitor nm, NetworkCapabilities nc) {
+        nm.notifyNetworkCapabilitiesChanged(nc);
+        waitForIdle(nm.getHandler());
+    }
+
     private void waitForIdle(Handler handler) {
         final ConditionVariable cv = new ConditionVariable(false);
         handler.post(cv::open);
@@ -256,7 +253,7 @@
         setSslException(mHttpsConnection);
         setPortal302(mHttpConnection);
 
-        assertPortal(makeMonitor().isCaptivePortal());
+        runPortalNetworkTest();
     }
 
     @Test
@@ -264,17 +261,7 @@
         setStatus(mHttpsConnection, 204);
         setStatus(mHttpConnection, 500);
 
-        assertNotPortal(makeMonitor().isCaptivePortal());
-    }
-
-    @Test
-    public void testIsCaptivePortal_HttpsProbeFailedHttpSuccessNotUsed() throws IOException {
-        setSslException(mHttpsConnection);
-        // Even if HTTP returns a 204, do not use the result unless HTTPS succeeded
-        setStatus(mHttpConnection, 204);
-        setStatus(mFallbackConnection, 500);
-
-        assertFailed(makeMonitor().isCaptivePortal());
+        runNotPortalNetworkTest();
     }
 
     @Test
@@ -283,17 +270,17 @@
         setStatus(mHttpConnection, 500);
         setPortal302(mFallbackConnection);
 
-        assertPortal(makeMonitor().isCaptivePortal());
+        runPortalNetworkTest();
     }
 
     @Test
     public void testIsCaptivePortal_FallbackProbeIsNotPortal() throws IOException {
         setSslException(mHttpsConnection);
         setStatus(mHttpConnection, 500);
-        setStatus(mFallbackConnection, 204);
+        setStatus(mFallbackConnection, 500);
 
         // Fallback probe did not see portal, HTTPS failed -> inconclusive
-        assertFailed(makeMonitor().isCaptivePortal());
+        runFailedNetworkTest();
     }
 
     @Test
@@ -310,15 +297,15 @@
         // TEST_OTHER_FALLBACK_URL is third
         when(mRandom.nextInt()).thenReturn(2);
 
-        final NetworkMonitor monitor = makeMonitor();
-
         // First check always uses the first fallback URL: inconclusive
-        assertFailed(monitor.isCaptivePortal());
+        final NetworkMonitor monitor = runNetworkTest(NETWORK_TEST_RESULT_INVALID);
+        assertNull(mNetworkTestedRedirectUrlCaptor.getValue());
         verify(mFallbackConnection, times(1)).getResponseCode();
         verify(mOtherFallbackConnection, never()).getResponseCode();
 
         // Second check uses the URL chosen by Random
-        assertPortal(monitor.isCaptivePortal());
+        final CaptivePortalProbeResult result = monitor.isCaptivePortal();
+        assertTrue(result.isPortal());
         verify(mOtherFallbackConnection, times(1)).getResponseCode();
     }
 
@@ -328,7 +315,7 @@
         setStatus(mHttpConnection, 500);
         setStatus(mFallbackConnection, 404);
 
-        assertFailed(makeMonitor().isCaptivePortal());
+        runFailedNetworkTest();
         verify(mFallbackConnection, times(1)).getResponseCode();
         verify(mOtherFallbackConnection, never()).getResponseCode();
     }
@@ -342,7 +329,7 @@
         setStatus(mHttpConnection, 500);
         setPortal302(mOtherFallbackConnection);
 
-        assertPortal(makeMonitor().isCaptivePortal());
+        runPortalNetworkTest();
         verify(mOtherFallbackConnection, times(1)).getResponseCode();
         verify(mFallbackConnection, never()).getResponseCode();
     }
@@ -360,12 +347,12 @@
     }
 
     @Test
-    public void testIsCaptivePortal_FallbackSpecIsNotPortal() throws IOException {
+    public void testIsCaptivePortal_FallbackSpecIsPartial() throws IOException {
         setupFallbackSpec();
         set302(mOtherFallbackConnection, "https://www.google.com/test?q=3");
 
-        // HTTPS failed, fallback spec did not see a portal -> inconclusive
-        assertFailed(makeMonitor().isCaptivePortal());
+        // HTTPS failed, fallback spec went through -> partial connectivity
+        runPartialConnectivityNetworkTest();
         verify(mOtherFallbackConnection, times(1)).getResponseCode();
         verify(mFallbackConnection, never()).getResponseCode();
     }
@@ -375,7 +362,7 @@
         setupFallbackSpec();
         set302(mOtherFallbackConnection, "http://login.portal.example.com");
 
-        assertPortal(makeMonitor().isCaptivePortal());
+        runPortalNetworkTest();
     }
 
     @Test
@@ -384,20 +371,20 @@
         setSslException(mHttpsConnection);
         setPortal302(mHttpConnection);
 
-        assertNotPortal(makeMonitor().isCaptivePortal());
+        runNotPortalNetworkTest();
     }
 
     @Test
     public void testIsDataStall_EvaluationDisabled() {
         setDataStallEvaluationType(0);
-        WrappedNetworkMonitor wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 100);
         assertFalse(wrappedMonitor.isDataStall());
     }
 
     @Test
     public void testIsDataStall_EvaluationDnsOnNotMeteredNetwork() {
-        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 100);
         makeDnsTimeoutEvent(wrappedMonitor, DEFAULT_DNS_TIMEOUT_THRESHOLD);
         assertTrue(wrappedMonitor.isDataStall());
@@ -405,7 +392,7 @@
 
     @Test
     public void testIsDataStall_EvaluationDnsOnMeteredNetwork() {
-        WrappedNetworkMonitor wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 100);
         assertFalse(wrappedMonitor.isDataStall());
 
@@ -416,7 +403,7 @@
 
     @Test
     public void testIsDataStall_EvaluationDnsWithDnsTimeoutCount() {
-        WrappedNetworkMonitor wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000);
         makeDnsTimeoutEvent(wrappedMonitor, 3);
         assertFalse(wrappedMonitor.isDataStall());
@@ -430,7 +417,7 @@
 
         // Set the value to larger than the default dns log size.
         setConsecutiveDnsTimeoutThreshold(51);
-        wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000);
         makeDnsTimeoutEvent(wrappedMonitor, 50);
         assertFalse(wrappedMonitor.isDataStall());
@@ -442,7 +429,7 @@
     @Test
     public void testIsDataStall_EvaluationDnsWithDnsTimeThreshold() {
         // Test dns events happened in valid dns time threshold.
-        WrappedNetworkMonitor wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 100);
         makeDnsTimeoutEvent(wrappedMonitor, DEFAULT_DNS_TIMEOUT_THRESHOLD);
         assertFalse(wrappedMonitor.isDataStall());
@@ -451,7 +438,7 @@
 
         // Test dns events happened before valid dns time threshold.
         setValidDataStallDnsTimeThreshold(0);
-        wrappedMonitor = makeMeteredWrappedNetworkMonitor();
+        wrappedMonitor = makeMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 100);
         makeDnsTimeoutEvent(wrappedMonitor, DEFAULT_DNS_TIMEOUT_THRESHOLD);
         assertFalse(wrappedMonitor.isDataStall());
@@ -464,24 +451,13 @@
         setSslException(mHttpsConnection);
         setStatus(mHttpConnection, 500);
         setStatus(mFallbackConnection, 404);
-        when(mCm.getNetworkCapabilities(any())).thenReturn(METERED_CAPABILITIES);
 
-        final NetworkMonitor nm = makeMonitor();
-        nm.notifyNetworkConnected();
-
-        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
-                .notifyNetworkTested(NETWORK_TEST_RESULT_INVALID, null);
+        runFailedNetworkTest();
     }
 
     @Test
     public void testNoInternetCapabilityValidated() throws Exception {
-        when(mCm.getNetworkCapabilities(any())).thenReturn(NO_INTERNET_CAPABILITIES);
-
-        final NetworkMonitor nm = makeMonitor();
-        nm.notifyNetworkConnected();
-
-        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
-                .notifyNetworkTested(NETWORK_TEST_RESULT_VALID, null);
+        runNetworkTest(NO_INTERNET_CAPABILITIES, NETWORK_TEST_RESULT_VALID);
         verify(mNetwork, never()).openConnection(any());
     }
 
@@ -491,7 +467,7 @@
         setPortal302(mHttpConnection);
 
         final NetworkMonitor nm = makeMonitor();
-        nm.notifyNetworkConnected();
+        nm.notifyNetworkConnected(TEST_LINK_PROPERTIES, METERED_CAPABILITIES);
 
         verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                 .showProvisioningNotification(any(), any());
@@ -522,7 +498,7 @@
 
     @Test
     public void testDataStall_StallSuspectedAndSendMetrics() throws IOException {
-        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000);
         makeDnsTimeoutEvent(wrappedMonitor, 5);
         assertTrue(wrappedMonitor.isDataStall());
@@ -531,7 +507,7 @@
 
     @Test
     public void testDataStall_NoStallSuspectedAndSendMetrics() throws IOException {
-        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000);
         makeDnsTimeoutEvent(wrappedMonitor, 3);
         assertFalse(wrappedMonitor.isDataStall());
@@ -540,7 +516,7 @@
 
     @Test
     public void testCollectDataStallMetrics() {
-        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredWrappedNetworkMonitor();
+        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
 
         when(mTelephony.getDataNetworkType()).thenReturn(TelephonyManager.NETWORK_TYPE_LTE);
         when(mTelephony.getNetworkOperator()).thenReturn(TEST_MCCMNC);
@@ -578,14 +554,11 @@
         setSslException(mHttpsConnection);
         setStatus(mHttpConnection, 204);
 
-        final NetworkMonitor nm = makeMonitor();
-        nm.notifyNetworkConnected();
-        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
-                .notifyNetworkTested(NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY, null);
+        final NetworkMonitor nm = runNetworkTest(NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY);
 
         nm.setAcceptPartialConnectivity();
         verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
-                .notifyNetworkTested(NETWORK_TEST_RESULT_VALID, null);
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), any());
     }
 
     @Test
@@ -593,12 +566,12 @@
         setStatus(mHttpsConnection, 500);
         setStatus(mHttpConnection, 204);
         setStatus(mFallbackConnection, 500);
-        assertPartialConnectivity(makeMonitor().isCaptivePortal());
+        runPartialConnectivityNetworkTest();
 
         setStatus(mHttpsConnection, 500);
         setStatus(mHttpConnection, 500);
         setStatus(mFallbackConnection, 204);
-        assertPartialConnectivity(makeMonitor().isCaptivePortal());
+        runPartialConnectivityNetworkTest();
     }
 
     private void makeDnsTimeoutEvent(WrappedNetworkMonitor wrappedMonitor, int count) {
@@ -660,26 +633,41 @@
                 eq(Settings.Global.CAPTIVE_PORTAL_MODE), anyInt())).thenReturn(mode);
     }
 
-    private void assertPortal(CaptivePortalProbeResult result) {
-        assertTrue(result.isPortal());
-        assertFalse(result.isFailed());
-        assertFalse(result.isSuccessful());
+    private void runPortalNetworkTest() {
+        runNetworkTest(NETWORK_TEST_RESULT_INVALID);
+        assertNotNull(mNetworkTestedRedirectUrlCaptor.getValue());
     }
 
-    private void assertNotPortal(CaptivePortalProbeResult result) {
-        assertFalse(result.isPortal());
-        assertFalse(result.isFailed());
-        assertTrue(result.isSuccessful());
+    private void runNotPortalNetworkTest() {
+        runNetworkTest(NETWORK_TEST_RESULT_VALID);
+        assertNull(mNetworkTestedRedirectUrlCaptor.getValue());
     }
 
-    private void assertFailed(CaptivePortalProbeResult result) {
-        assertFalse(result.isPortal());
-        assertTrue(result.isFailed());
-        assertFalse(result.isSuccessful());
+    private void runFailedNetworkTest() {
+        runNetworkTest(NETWORK_TEST_RESULT_INVALID);
+        assertNull(mNetworkTestedRedirectUrlCaptor.getValue());
     }
 
-    private void assertPartialConnectivity(CaptivePortalProbeResult result) {
-        assertTrue(result.isPartialConnectivity());
+    private void runPartialConnectivityNetworkTest() {
+        runNetworkTest(NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY);
+        assertNull(mNetworkTestedRedirectUrlCaptor.getValue());
+    }
+
+    private NetworkMonitor runNetworkTest(int testResult) {
+        return runNetworkTest(METERED_CAPABILITIES, testResult);
+    }
+
+    private NetworkMonitor runNetworkTest(NetworkCapabilities nc, int testResult) {
+        final NetworkMonitor monitor = makeMonitor();
+        monitor.notifyNetworkConnected(TEST_LINK_PROPERTIES, nc);
+        try {
+            verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                    .notifyNetworkTested(eq(testResult), mNetworkTestedRedirectUrlCaptor.capture());
+        } catch (RemoteException e) {
+            fail("Unexpected exception: " + e);
+        }
+
+        return monitor;
     }
 
     private void setSslException(HttpURLConnection connection) throws IOException {