Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OkHttpClient.Builder network pinning on Android #8376

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
31 changes: 21 additions & 10 deletions okhttp-android/api/okhttp-android.api
@@ -1,13 +1,6 @@
public final class okhttp3/android/AndroidAsyncDns : okhttp3/AsyncDns {
public static final field Companion Lokhttp3/android/AndroidAsyncDns$Companion;
public fun <init> (Lokhttp3/AsyncDns$DnsClass;Landroid/net/Network;)V
public synthetic fun <init> (Lokhttp3/AsyncDns$DnsClass;Landroid/net/Network;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun query (Ljava/lang/String;Lokhttp3/AsyncDns$Callback;)V
}

public final class okhttp3/android/AndroidAsyncDns$Companion {
public final fun getIPv4 ()Lokhttp3/android/AndroidAsyncDns;
public final fun getIPv6 ()Lokhttp3/android/AndroidAsyncDns;
public final class okhttp3/android/AndroidDnsKt {
public static final fun forNetwork (Lokhttp3/Dns$Companion;Landroid/net/Network;)Lokhttp3/Dns;
public static final fun getANDROID (Lokhttp3/Dns$Companion;)Lokhttp3/Dns;
}

public final class okhttp3/android/AndroidLoggingKt {
Expand All @@ -17,3 +10,21 @@ public final class okhttp3/android/AndroidLoggingKt {
public static synthetic fun androidLogging$default (Lokhttp3/logging/LoggingEventListener$Companion;ILjava/lang/String;ILjava/lang/Object;)Lokhttp3/logging/LoggingEventListener$Factory;
}

public final class okhttp3/android/AndroidSocketFactory : javax/net/SocketFactory {
public fun <init> (Landroid/net/Network;)V
public fun createSocket ()Ljava/net/Socket;
public fun createSocket (Ljava/lang/String;I)Ljava/net/Socket;
public fun createSocket (Ljava/lang/String;ILjava/net/InetAddress;I)Ljava/net/Socket;
public fun createSocket (Ljava/net/InetAddress;I)Ljava/net/Socket;
public fun createSocket (Ljava/net/InetAddress;ILjava/net/InetAddress;I)Ljava/net/Socket;
public fun equals (Ljava/lang/Object;)Z
public final fun getNetwork ()Landroid/net/Network;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class okhttp3/android/NetworkSelection {
public static final field INSTANCE Lokhttp3/android/NetworkSelection;
public final fun withNetwork (Lokhttp3/OkHttpClient$Builder;Landroid/net/Network;)Lokhttp3/OkHttpClient$Builder;
}

Expand Up @@ -32,10 +32,11 @@ import java.net.UnknownHostException
import java.util.concurrent.CountDownLatch
import mockwebserver3.MockResponse
import mockwebserver3.junit4.MockWebServerRule
import okhttp3.AsyncDns
import okhttp3.Dns
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.android.internal.AsyncDns
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import okio.IOException
Expand All @@ -49,7 +50,7 @@ import org.junit.Test
/**
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/
class AndroidAsyncDnsTest {
class AndroidDnsTest {
@JvmField @Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient
Expand All @@ -73,7 +74,7 @@ class AndroidAsyncDnsTest {

client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(Dns.ANDROID)
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()

Expand Down Expand Up @@ -130,25 +131,33 @@ class AndroidAsyncDnsTest {
val latch = CountDownLatch(1)

// assumes an IPv4 address
AndroidAsyncDns.IPv4.query(
hostname,
object : AsyncDns.Callback {
override fun onResponse(
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}

override fun onFailure(
hostname: String,
e: IOException,
) {
exception = e
latch.countDown()
}
},
AndroidDns(AndroidDns.DnsClass.IPV4).query(
hostname = hostname,
originatingCall = null,
callback =
object : AsyncDns.Callback {
override fun onAddresses(
hasMore: Boolean,
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
if (!hasMore) {
latch.countDown()
}
}

override fun onFailure(
hasMore: Boolean,
hostname: String,
e: IOException,
) {
exception = e
if (!hasMore) {
latch.countDown()
}
}
},
)

latch.await()
Expand Down Expand Up @@ -187,7 +196,7 @@ class AndroidAsyncDnsTest {

val client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(Dns.ANDROID)
.socketFactory(network.socketFactory)
.build()

Expand All @@ -199,11 +208,13 @@ class AndroidAsyncDnsTest {
}
}

private fun assumeNetwork() {
try {
InetAddress.getByName("www.google.com")
} catch (uhe: UnknownHostException) {
throw AssumptionViolatedException(uhe.message, uhe)
companion object {
fun assumeNetwork() {
try {
InetAddress.getByName("www.google.com")
} catch (uhe: UnknownHostException) {
throw AssumptionViolatedException(uhe.message, uhe)
}
}
}
}
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package okhttp3.android

import android.net.ConnectivityManager
import android.net.Network
import android.os.Build
import androidx.test.platform.app.InstrumentationRegistry
import assertk.assertThat
import assertk.assertions.isEqualTo
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.android.AndroidDnsTest.Companion.assumeNetwork
import okhttp3.android.NetworkSelection.withNetwork
import org.junit.Assume.assumeTrue
import org.junit.Before
import org.junit.Test

class AndroidNetworkSelectionTest {
private var workingNetwork: Network? = null

private lateinit var client: OkHttpClient

@Before
fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
assumeNetwork()

val connectivityManager =
InstrumentationRegistry.getInstrumentation().context.getSystemService(ConnectivityManager::class.java)

workingNetwork = connectivityManager.activeNetwork

assumeTrue(workingNetwork != null)

client =
OkHttpClient.Builder()
.withNetwork(network = workingNetwork).build()
}

@Test
fun testRequest() {
val call = client.newCall(Request("https://www.google.com/robots.txt".toHttpUrl()))

call.execute().use { response ->
assertThat(response.code).isEqualTo(200)
}
}
}
Expand Up @@ -24,8 +24,13 @@ import androidx.annotation.RequiresApi
import java.net.InetAddress
import java.net.UnknownHostException
import java.util.concurrent.Executors
import okhttp3.AsyncDns
import okhttp3.Call
import okhttp3.Dns
import okhttp3.ExperimentalOkHttpApi
import okhttp3.android.AndroidDns.DnsClass
import okhttp3.android.internal.AsyncDns
import okhttp3.android.internal.BlockingAsyncDns.Companion.asBlocking
import okhttp3.android.internal.CombinedAsyncDns.Companion.union

/**
* DNS implementation based on android.net.DnsResolver, which submits a request for
Expand All @@ -37,8 +42,8 @@ import okhttp3.ExperimentalOkHttpApi
*/
@RequiresApi(Build.VERSION_CODES.Q)
@ExperimentalOkHttpApi
class AndroidAsyncDns(
private val dnsClass: AsyncDns.DnsClass,
internal class AndroidDns internal constructor(
private val dnsClass: DnsClass,
private val network: Network? = null,
) : AsyncDns {
@RequiresApi(Build.VERSION_CODES.Q)
Expand All @@ -47,6 +52,7 @@ class AndroidAsyncDns(

override fun query(
hostname: String,
originatingCall: Call?,
callback: AsyncDns.Callback,
) {
try {
Expand All @@ -62,15 +68,17 @@ class AndroidAsyncDns(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses)
callback.onAddresses(hasMore = false, hostname = hostname, addresses = addresses)
}

override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(
hostname,
UnknownHostException(e.message).apply {
initCause(e)
},
hasMore = false,
hostname = hostname,
e =
UnknownHostException(e.message).apply {
initCause(e)
},
)
}
},
Expand All @@ -79,6 +87,7 @@ class AndroidAsyncDns(
// Handle any errors that might leak out
// https://issuetracker.google.com/issues/319957694
callback.onFailure(
hasMore = false,
hostname,
UnknownHostException(e.message).apply {
initCause(e)
Expand All @@ -89,10 +98,31 @@ class AndroidAsyncDns(

@ExperimentalOkHttpApi
companion object {
@RequiresApi(Build.VERSION_CODES.Q)
val IPv4 = AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV4)
internal fun forNetwork(network: Network): AsyncDns {
return union(
AndroidDns(dnsClass = DnsClass.IPV4, network = network),
AndroidDns(dnsClass = DnsClass.IPV6, network = network),
)
}

@RequiresApi(Build.VERSION_CODES.Q)
val IPv6 = AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV6)
internal const val TYPE_A = 1
internal const val TYPE_AAAA = 28
}

/**
* Class of DNS addresses, such that clients that treat these differently, such
* as attempting IPv6 first, can make such decisions.
*/
@ExperimentalOkHttpApi
internal enum class DnsClass(val type: Int) {
IPV4(TYPE_A),
IPV6(TYPE_AAAA),
}
}

val Dns.Companion.ANDROID: Dns
@RequiresApi(Build.VERSION_CODES.Q)
get() = union(AndroidDns(dnsClass = DnsClass.IPV4), AndroidDns(dnsClass = DnsClass.IPV6)).asBlocking()

@RequiresApi(Build.VERSION_CODES.Q)
fun Dns.Companion.forNetwork(network: Network): Dns = AndroidDns.forNetwork(network).asBlocking()