Use async dns query to resolve all addresses

Currently, it looks like private DNS server resolution uses
OneAddressPerFamilyNetwork and only returns one server address.
It should return all addresses. Use async dns api for this.

Bug: 123435238
Test: atest NetworkStacktests
Change-Id: I9f50da3c8c2e3b12b29bc8844291e4bf1559cd1f
This commit is contained in:
Chiachang Wang
2019-05-09 21:28:47 +08:00
parent abfef61707
commit e37f8729d1
3 changed files with 143 additions and 41 deletions

View File

@@ -0,0 +1,119 @@
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.networkstack.util;
import static android.net.DnsResolver.FLAG_NO_CACHE_LOOKUP;
import static android.net.DnsResolver.TYPE_A;
import static android.net.DnsResolver.TYPE_AAAA;
import android.annotation.NonNull;
import android.net.DnsResolver;
import android.net.Network;
import android.net.TrafficStats;
import android.util.Log;
import com.android.internal.util.TrafficStatsConstants;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/**
* Collection of utilities for dns query.
*/
public class DnsUtils {
// Decide what queries to make depending on what IP addresses are on the system.
public static final int TYPE_ADDRCONFIG = -1;
private static final String TAG = DnsUtils.class.getSimpleName();
/**
* Return both A and AAAA query results regardless the ip address type of the giving network.
* Used for probing in NetworkMonitor.
*/
@NonNull
public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
@NonNull final Network network, @NonNull String host, int timeout)
throws UnknownHostException {
final List<InetAddress> result = new ArrayList<InetAddress>();
result.addAll(Arrays.asList(
getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
timeout)));
result.addAll(Arrays.asList(
getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP,
timeout)));
return result.toArray(new InetAddress[0]);
}
/**
* Return dns query result based on the given QueryType(TYPE_A, TYPE_AAAA) or TYPE_ADDRCONFIG.
* Used for probing in NetworkMonitor.
*/
@NonNull
public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
@NonNull final Network network, @NonNull final String host, int type, int flag,
int timeoutMs) throws UnknownHostException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
final DnsResolver.Callback<List<InetAddress>> callback =
new DnsResolver.Callback<List<InetAddress>>() {
@Override
public void onAnswer(List<InetAddress> answer, int rcode) {
if (rcode == 0) {
resultRef.set(answer);
}
latch.countDown();
}
@Override
public void onError(@NonNull DnsResolver.DnsException e) {
Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage());
latch.countDown();
}
};
final int oldTag = TrafficStats.getAndSetThreadStatsTag(
TrafficStatsConstants.TAG_SYSTEM_PROBE);
if (type == TYPE_ADDRCONFIG) {
dnsResolver.query(network, host, flag, r -> r.run(), null /* cancellationSignal */,
callback);
} else {
dnsResolver.query(network, host, type, flag, r -> r.run(),
null /* cancellationSignal */, callback);
}
TrafficStats.setThreadStatsTag(oldTag);
try {
latch.await(timeoutMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
}
final List<InetAddress> result = resultRef.get();
if (result == null || result.size() == 0) {
throw new UnknownHostException(host);
}
return result.toArray(new InetAddress[0]);
}
}

View File

@@ -23,6 +23,7 @@ import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_PROBE_SPEC;
import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_URL;
import static android.net.ConnectivityManager.TYPE_MOBILE;
import static android.net.ConnectivityManager.TYPE_WIFI;
import static android.net.DnsResolver.FLAG_EMPTY;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_INVALID;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY;
import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
@@ -56,6 +57,8 @@ import static android.net.util.NetworkStackUtils.CAPTIVE_PORTAL_USE_HTTPS;
import static android.net.util.NetworkStackUtils.NAMESPACE_CONNECTIVITY;
import static android.net.util.NetworkStackUtils.isEmpty;
import static com.android.networkstack.util.DnsUtils.TYPE_ADDRCONFIG;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.app.PendingIntent;
@@ -113,6 +116,7 @@ import com.android.internal.util.TrafficStatsConstants;
import com.android.networkstack.R;
import com.android.networkstack.metrics.DataStallDetectionStats;
import com.android.networkstack.metrics.DataStallStatsUtils;
import com.android.networkstack.util.DnsUtils;
import java.io.IOException;
import java.net.HttpURLConnection;
@@ -129,7 +133,6 @@ import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
/**
@@ -994,8 +997,8 @@ public class NetworkMonitor extends StateMachine {
private void resolveStrictModeHostname() {
try {
// Do a blocking DNS resolution using the network-assigned nameservers.
final InetAddress[] ips = mCleartextDnsNetwork.getAllByName(
mPrivateDnsProviderHostname);
final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(),
mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout());
mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
} catch (UnknownHostException uhe) {
@@ -1489,39 +1492,8 @@ public class NetworkMonitor extends StateMachine {
@VisibleForTesting
protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
throws UnknownHostException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
final DnsResolver.Callback<List<InetAddress>> callback =
new DnsResolver.Callback<List<InetAddress>>() {
public void onAnswer(List<InetAddress> answer, int rcode) {
if (rcode == 0) {
resultRef.set(answer);
}
latch.countDown();
}
public void onError(@NonNull DnsResolver.DnsException e) {
validationLog("DNS error resolving " + host + ": " + e.getMessage());
latch.countDown();
}
};
final int oldTag = TrafficStats.getAndSetThreadStatsTag(
TrafficStatsConstants.TAG_SYSTEM_PROBE);
mDependencies.getDnsResolver().query(mCleartextDnsNetwork, host, DnsResolver.FLAG_EMPTY,
r -> r.run() /* executor */, null /* cancellationSignal */, callback);
TrafficStats.setThreadStatsTag(oldTag);
try {
latch.await(timeoutMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
}
List<InetAddress> result = resultRef.get();
if (result == null || result.size() == 0) {
throw new UnknownHostException(host);
}
return result.toArray(new InetAddress[0]);
return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host,
TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs);
}
/** Do a DNS resolution of the given server. */

View File

@@ -226,11 +226,6 @@ public class NetworkMonitorTest {
/** Starts mocking DNS queries. */
private void startMocking() throws UnknownHostException {
// Queries on mCleartextDnsNetwork using getAllByName.
doAnswer(invocation -> {
return getAllByName(invocation.getMock(), invocation.getArgument(0));
}).when(mCleartextDnsNetwork).getAllByName(any());
// Queries on mNetwork using getAllByName.
doAnswer(invocation -> {
return getAllByName(invocation.getMock(), invocation.getArgument(0));
@@ -251,6 +246,22 @@ public class NetworkMonitorTest {
// If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
return null;
}).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
// Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType.
doAnswer(invocation -> {
String hostname = (String) invocation.getArgument(1);
Executor executor = (Executor) invocation.getArgument(4);
DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6);
List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
if (answer != null && answer.size() > 0) {
new Handler(Looper.getMainLooper()).post(() -> {
executor.execute(() -> callback.onAnswer(answer, 0));
});
}
// If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
return null;
}).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
}
}