Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OkHttpClient.Builder network pinning on Android #8376

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
32 changes: 31 additions & 1 deletion okhttp-android/api/okhttp-android.api
Expand Up @@ -2,10 +2,12 @@ public final class okhttp3/android/AndroidAsyncDns : okhttp3/AsyncDns {
public static final field Companion Lokhttp3/android/AndroidAsyncDns$Companion;
public fun <init> (Lokhttp3/AsyncDns$DnsClass;Landroid/net/Network;)V
public synthetic fun <init> (Lokhttp3/AsyncDns$DnsClass;Landroid/net/Network;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun query (Ljava/lang/String;Lokhttp3/AsyncDns$Callback;)V
public fun query (Ljava/lang/String;Lokhttp3/Call;Lokhttp3/AsyncDns$Callback;)V
}

public final class okhttp3/android/AndroidAsyncDns$Companion {
public final fun forNetwork (Landroid/net/Network;)Lokhttp3/AsyncDns;
public final fun getDEFAULT ()Lokhttp3/AsyncDns;
public final fun getIPv4 ()Lokhttp3/android/AndroidAsyncDns;
public final fun getIPv6 ()Lokhttp3/android/AndroidAsyncDns;
}
Expand All @@ -17,3 +19,31 @@ public final class okhttp3/android/AndroidLoggingKt {
public static synthetic fun androidLogging$default (Lokhttp3/logging/LoggingEventListener$Companion;ILjava/lang/String;ILjava/lang/Object;)Lokhttp3/logging/LoggingEventListener$Factory;
}

public final class okhttp3/android/AndroidSocketFactory : javax/net/SocketFactory {
public fun <init> (Landroid/net/Network;)V
public fun createSocket (Ljava/lang/String;I)Ljava/net/Socket;
public fun createSocket (Ljava/lang/String;ILjava/net/InetAddress;I)Ljava/net/Socket;
public fun createSocket (Ljava/net/InetAddress;I)Ljava/net/Socket;
public fun createSocket (Ljava/net/InetAddress;ILjava/net/InetAddress;I)Ljava/net/Socket;
public fun equals (Ljava/lang/Object;)Z
public final fun getNetwork ()Landroid/net/Network;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class okhttp3/android/NetworkPin {
public fun <init> (Landroid/net/Network;)V
public final fun component1 ()Landroid/net/Network;
public final fun copy (Landroid/net/Network;)Lokhttp3/android/NetworkPin;
public static synthetic fun copy$default (Lokhttp3/android/NetworkPin;Landroid/net/Network;ILjava/lang/Object;)Lokhttp3/android/NetworkPin;
public fun equals (Ljava/lang/Object;)Z
public final fun getNetwork ()Landroid/net/Network;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class okhttp3/android/NetworkSelection {
public static final field INSTANCE Lokhttp3/android/NetworkSelection;
public final fun withNetwork (Lokhttp3/OkHttpClient$Builder;Landroid/net/Network;)Lokhttp3/OkHttpClient$Builder;
}

Expand Up @@ -73,7 +73,7 @@ class AndroidAsyncDnsTest {

client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(AndroidAsyncDns.DEFAULT.asBlocking())
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()

Expand Down Expand Up @@ -131,24 +131,32 @@ class AndroidAsyncDnsTest {

// assumes an IPv4 address
AndroidAsyncDns.IPv4.query(
hostname,
object : AsyncDns.Callback {
override fun onResponse(
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}

override fun onFailure(
hostname: String,
e: IOException,
) {
exception = e
latch.countDown()
}
},
hostname = hostname,
originatingCall = null,
callback =
object : AsyncDns.Callback {
override fun onAddresses(
hasMore: Boolean,
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
if (!hasMore) {
latch.countDown()
}
}

override fun onFailure(
hasMore: Boolean,
hostname: String,
e: IOException,
) {
exception = e
if (!hasMore) {
latch.countDown()
}
}
},
)

latch.await()
Expand Down Expand Up @@ -187,7 +195,7 @@ class AndroidAsyncDnsTest {

val client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(AndroidAsyncDns.DEFAULT.asBlocking())
.socketFactory(network.socketFactory)
.build()

Expand Down
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package okhttp3.android

import android.net.ConnectivityManager
import android.net.Network
import android.os.Build
import androidx.test.platform.app.InstrumentationRegistry
import assertk.assertThat
import assertk.assertions.isEqualTo
import mockwebserver3.MockResponse
import mockwebserver3.junit4.MockWebServerRule
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.android.NetworkSelection.withNetwork
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import org.junit.Assume.assumeTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test

class AndroidNetworkSelectionTest {
private var activeNetwork: Network? = null

@JvmField
@Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient

private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate =
HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build()
return@lazy HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
}

@Before
fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)

val connectivityManager =
InstrumentationRegistry.getInstrumentation().context.getSystemService(ConnectivityManager::class.java)

activeNetwork = connectivityManager.activeNetwork
assumeTrue(activeNetwork != null)

client =
OkHttpClient.Builder()
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.withNetwork(network = activeNetwork)
.build()

serverRule.server.useHttps(localhost.sslSocketFactory())
}

@Test
fun testRequest() {
serverRule.server.enqueue(MockResponse())

val call = client.newCall(Request(serverRule.server.url("/")))

call.execute().use { response ->
assertThat(response.code).isEqualTo(200)
assertThat(response.request.tag<NetworkPin>()?.network).isEqualTo(activeNetwork)
}
}
}
24 changes: 19 additions & 5 deletions okhttp-android/src/main/kotlin/okhttp3/android/AndroidAsyncDns.kt
Expand Up @@ -25,6 +25,7 @@ import java.net.InetAddress
import java.net.UnknownHostException
import java.util.concurrent.Executors
import okhttp3.AsyncDns
import okhttp3.Call
import okhttp3.ExperimentalOkHttpApi

/**
Expand All @@ -47,6 +48,7 @@ class AndroidAsyncDns(

override fun query(
hostname: String,
originatingCall: Call?,
yschimke marked this conversation as resolved.
Show resolved Hide resolved
callback: AsyncDns.Callback,
) {
try {
Expand All @@ -62,15 +64,17 @@ class AndroidAsyncDns(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses)
callback.onAddresses(hasMore = false, hostname = hostname, addresses = addresses)
}

override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(
hostname,
UnknownHostException(e.message).apply {
initCause(e)
},
hasMore = false,
hostname = hostname,
e =
UnknownHostException(e.message).apply {
initCause(e)
},
)
}
},
Expand All @@ -79,6 +83,7 @@ class AndroidAsyncDns(
// Handle any errors that might leak out
// https://issuetracker.google.com/issues/319957694
callback.onFailure(
hasMore = false,
hostname,
UnknownHostException(e.message).apply {
initCause(e)
Expand All @@ -94,5 +99,14 @@ class AndroidAsyncDns(

@RequiresApi(Build.VERSION_CODES.Q)
val IPv6 = AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV6)

val DEFAULT: AsyncDns = AsyncDns.union(IPv4, IPv6)

fun forNetwork(network: Network): AsyncDns {
return AsyncDns.union(
AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV4, network = network),
AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV6, network = network),
)
}
}
}
@@ -0,0 +1,80 @@
/*
* Copyright (C) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.android

import android.net.Network
import android.net.TrafficStats
import java.net.InetAddress
import java.net.Socket
import javax.net.SocketFactory
import okhttp3.ExperimentalOkHttpApi

@ExperimentalOkHttpApi
class AndroidSocketFactory(
val network: Network,
) : SocketFactory() {
private val socketFactory: SocketFactory = network.socketFactory

override fun createSocket(
host: String?,
port: Int,
): Socket {
return socketFactory.createSocket(host, port).also { configure(it) }
}

private fun configure(it: Socket) {
println("Tagging socket on ${Thread.currentThread().name}")
TrafficStats.tagSocket(it)
}

override fun createSocket(
host: String?,
port: Int,
localHost: InetAddress?,
localPort: Int,
): Socket {
return socketFactory.createSocket(host, port, localHost, localPort).also { configure(it) }
}

override fun createSocket(
host: InetAddress?,
port: Int,
): Socket {
return socketFactory.createSocket(host, port).also { configure(it) }
}

override fun createSocket(
address: InetAddress?,
port: Int,
localAddress: InetAddress?,
localPort: Int,
): Socket {
return socketFactory.createSocket(address, port, localAddress, localPort).also { configure(it) }
}

override fun hashCode(): Int {
return network.networkHandle.hashCode()
}

override fun equals(other: Any?): Boolean {
return other is AndroidSocketFactory &&
network == other.network
}

override fun toString(): String {
return "AndroidSocketFactory{$network}"
}
}
22 changes: 22 additions & 0 deletions okhttp-android/src/main/kotlin/okhttp3/android/NetworkPin.kt
@@ -0,0 +1,22 @@
/*
* Copyright (C) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.android

import android.net.Network
import okhttp3.ExperimentalOkHttpApi

@ExperimentalOkHttpApi
data class NetworkPin(val network: Network?)
yschimke marked this conversation as resolved.
Show resolved Hide resolved