Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OkHttpClient.Builder network pinning on Android #8376

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Expand Up @@ -73,7 +73,7 @@ class AndroidAsyncDnsTest {

client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(AndroidAsyncDns.DEFAULT.asBlocking())
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()

Expand Down Expand Up @@ -131,24 +131,32 @@ class AndroidAsyncDnsTest {

// assumes an IPv4 address
AndroidAsyncDns.IPv4.query(
hostname,
object : AsyncDns.Callback {
override fun onResponse(
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}

override fun onFailure(
hostname: String,
e: IOException,
) {
exception = e
latch.countDown()
}
},
hostname = hostname,
originatingCall = null,
callback =
object : AsyncDns.Callback {
override fun onAddresses(
hasMore: Boolean,
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
if (!hasMore) {
latch.countDown()
}
}

override fun onFailure(
hasMore: Boolean,
hostname: String,
e: IOException,
) {
exception = e
if (!hasMore) {
latch.countDown()
}
}
},
)

latch.await()
Expand Down Expand Up @@ -187,7 +195,7 @@ class AndroidAsyncDnsTest {

val client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.dns(AndroidAsyncDns.DEFAULT.asBlocking())
.socketFactory(network.socketFactory)
.build()

Expand Down
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package okhttp3.android

import android.net.ConnectivityManager
import android.net.Network
import android.os.Build
import androidx.test.platform.app.InstrumentationRegistry
import assertk.assertThat
import assertk.assertions.isEqualTo
import mockwebserver3.MockResponse
import mockwebserver3.junit4.MockWebServerRule
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.android.NetworkSelection.withNetwork
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import org.junit.Assume.assumeTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test

class AndroidNetworkSelectionTest {
private var activeNetwork: Network? = null

@JvmField
@Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient

private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate =
HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build()
return@lazy HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
}

@Before
fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)

val connectivityManager =
InstrumentationRegistry.getInstrumentation().context.getSystemService(ConnectivityManager::class.java)

activeNetwork = connectivityManager.activeNetwork
assumeTrue(activeNetwork != null)

client =
OkHttpClient.Builder()
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.withNetwork(network = activeNetwork)
.build()

serverRule.server.useHttps(localhost.sslSocketFactory())
}

@Test
fun testRequest() {
serverRule.server.enqueue(MockResponse())

val call = client.newCall(Request(serverRule.server.url("/")))

call.execute().use { response ->
assertThat(response.code).isEqualTo(200)
assertThat(response.request.tag<NetworkPin>()?.network).isEqualTo(activeNetwork)
}
}
}
24 changes: 19 additions & 5 deletions okhttp-android/src/main/kotlin/okhttp3/android/AndroidAsyncDns.kt
Expand Up @@ -25,6 +25,7 @@ import java.net.InetAddress
import java.net.UnknownHostException
import java.util.concurrent.Executors
import okhttp3.AsyncDns
import okhttp3.Call
import okhttp3.ExperimentalOkHttpApi

/**
Expand All @@ -47,6 +48,7 @@ class AndroidAsyncDns(

override fun query(
hostname: String,
originatingCall: Call?,
yschimke marked this conversation as resolved.
Show resolved Hide resolved
callback: AsyncDns.Callback,
) {
try {
Expand All @@ -62,15 +64,17 @@ class AndroidAsyncDns(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses)
callback.onAddresses(hasMore = false, hostname = hostname, addresses = addresses)
}

override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(
hostname,
UnknownHostException(e.message).apply {
initCause(e)
},
hasMore = false,
hostname = hostname,
e =
UnknownHostException(e.message).apply {
initCause(e)
},
)
}
},
Expand All @@ -79,6 +83,7 @@ class AndroidAsyncDns(
// Handle any errors that might leak out
// https://issuetracker.google.com/issues/319957694
callback.onFailure(
hasMore = false,
hostname,
UnknownHostException(e.message).apply {
initCause(e)
Expand All @@ -94,5 +99,14 @@ class AndroidAsyncDns(

@RequiresApi(Build.VERSION_CODES.Q)
val IPv6 = AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV6)

val DEFAULT: AsyncDns = AsyncDns.union(IPv4, IPv6)

fun forNetwork(network: Network): AsyncDns {
return AsyncDns.union(
AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV4, network = network),
AndroidAsyncDns(dnsClass = AsyncDns.DnsClass.IPV6, network = network),
)
}
}
}
@@ -0,0 +1,80 @@
/*
* Copyright (C) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.android

import android.net.Network
import android.net.TrafficStats
import java.net.InetAddress
import java.net.Socket
import javax.net.SocketFactory
import okhttp3.ExperimentalOkHttpApi

@ExperimentalOkHttpApi
class AndroidSocketFactory(
val network: Network,
) : SocketFactory() {
private val socketFactory: SocketFactory = network.socketFactory

override fun createSocket(
host: String?,
port: Int,
): Socket {
return socketFactory.createSocket(host, port).also { configure(it) }
}

private fun configure(it: Socket) {
println("Tagging socket on ${Thread.currentThread().name}")
TrafficStats.tagSocket(it)
}

override fun createSocket(
host: String?,
port: Int,
localHost: InetAddress?,
localPort: Int,
): Socket {
return socketFactory.createSocket(host, port, localHost, localPort).also { configure(it) }
}

override fun createSocket(
host: InetAddress?,
port: Int,
): Socket {
return socketFactory.createSocket(host, port).also { configure(it) }
}

override fun createSocket(
address: InetAddress?,
port: Int,
localAddress: InetAddress?,
localPort: Int,
): Socket {
return socketFactory.createSocket(address, port, localAddress, localPort).also { configure(it) }
}

override fun hashCode(): Int {
return network.networkHandle.hashCode()
}

override fun equals(other: Any?): Boolean {
return other is AndroidSocketFactory &&
network == other.network
}

override fun toString(): String {
return "AndroidSocketFactory{$network}"
}
}
22 changes: 22 additions & 0 deletions okhttp-android/src/main/kotlin/okhttp3/android/NetworkPin.kt
@@ -0,0 +1,22 @@
/*
* Copyright (C) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.android

import android.net.Network
import okhttp3.ExperimentalOkHttpApi

@ExperimentalOkHttpApi
data class NetworkPin(val network: Network?)
yschimke marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions okhttp-android/src/main/kotlin/okhttp3/android/NetworkSelection.kt
@@ -0,0 +1,47 @@
/*
* Copyright (C) 2024 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.android

import android.net.Network
import android.os.Build
import androidx.annotation.RequiresApi
import javax.net.SocketFactory
import okhttp3.ExperimentalOkHttpApi
import okhttp3.OkHttpClient
import okhttp3.android.internal.NetworkPinInterceptor

@ExperimentalOkHttpApi
object NetworkSelection {
@RequiresApi(Build.VERSION_CODES.Q)
fun OkHttpClient.Builder.withNetwork(network: Network?): OkHttpClient.Builder {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this part of the API.

What happens if we attempt this in the main OkHttp module? We’d have a runtime dependency on Network, but you wouldn’t call this API if you don’t have it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am ok with either.

I guess ultimately we should choose either

a) we build android specific bits in main module, but you just don't use them if not needed.
b) we encourage every android app to include okhttp and okhttp-android and use those bits.

I'm not sure.

interceptors().apply {
removeIf { it is NetworkPinInterceptor }
}

return if (network == null) {
dns(AndroidAsyncDns.DEFAULT.asBlocking())
.socketFactory(SocketFactory.getDefault())
} else {
dns(AndroidAsyncDns.forNetwork(network).asBlocking())
.socketFactory(AndroidSocketFactory(network))
.apply {
interceptors().apply {
add(0, NetworkPinInterceptor(network))
}
}
}
}
}