Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Interceptors to add EventListeners. #7447

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Expand Up @@ -55,7 +55,9 @@ import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.data.Offset
import org.junit.jupiter.api.Assertions.assertTrue

open class RecordingEventListener : EventListener() {
open class RecordingEventListener(
private val enforceOrder: Boolean = true
yschimke marked this conversation as resolved.
Show resolved Hide resolved
) : EventListener() {
val eventSequence: Deque<CallEvent> = ConcurrentLinkedDeque()

private val forbiddenLocks = mutableListOf<Any>()
Expand Down Expand Up @@ -133,7 +135,7 @@ open class RecordingEventListener : EventListener() {
}

val startEvent = e.closes(-1L)
if (startEvent != null) {
if (enforceOrder && startEvent != null) {
assertTrue(eventSequence.any { it == e.closes(it.timestampNs) })
}

Expand Down
1 change: 1 addition & 0 deletions okhttp/api/okhttp.api
Expand Up @@ -752,6 +752,7 @@ public abstract interface class okhttp3/Interceptor$Chain {
public abstract fun readTimeoutMillis ()I
public abstract fun request ()Lokhttp3/Request;
public abstract fun withConnectTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain;
public abstract fun withEventListener (Lokhttp3/EventListener;)Lokhttp3/Interceptor$Chain;
public abstract fun withReadTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain;
public abstract fun withWriteTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain;
public abstract fun writeTimeoutMillis ()I
Expand Down
7 changes: 7 additions & 0 deletions okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt
Expand Up @@ -100,5 +100,12 @@ fun interface Interceptor {
fun writeTimeoutMillis(): Int

fun withWriteTimeout(timeout: Int, unit: TimeUnit): Chain

/**
* Add an EventListener to the Call instance, that will receive all
* subsequent events. The chain is not mutated, since the Call is shared
* when copying a Chain.
yschimke marked this conversation as resolved.
Show resolved Hide resolved
*/
fun withEventListener(eventListener: EventListener): Chain
}
}
@@ -0,0 +1,226 @@
/*
* Copyright (C) 2022 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.connection

import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.Proxy
import okhttp3.Call
import okhttp3.Connection
import okhttp3.EventListener
import okhttp3.Handshake
import okhttp3.HttpUrl
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response

internal class EventListenerList(
var existingListener: EventListener
) : EventListener() {
private var additionalListeners: MutableList<EventListener>? = null

inline fun forEachListener(fn: EventListener.() -> Unit) {
existingListener.fn()
synchronized(this) {
additionalListeners?.forEach(fn)
}
}

override fun callStart(call: Call) {
forEachListener {
callStart(call)
}
}

override fun proxySelectStart(call: Call, url: HttpUrl) {
forEachListener {
proxySelectStart(call, url)
}
}

override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List<Proxy>) {
forEachListener {
proxySelectEnd(call, url, proxies)
}
}

override fun dnsStart(call: Call, domainName: String) {
forEachListener {
dnsStart(call, domainName)
}
}

override fun dnsEnd(call: Call, domainName: String, inetAddressList: List<InetAddress>) {
forEachListener {
dnsEnd(call, domainName, inetAddressList)
}
}

override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) {
forEachListener {
connectStart(call, inetSocketAddress, proxy)
}
}

override fun secureConnectStart(call: Call) {
forEachListener {
secureConnectStart(call)
}
}

override fun secureConnectEnd(call: Call, handshake: Handshake?) {
forEachListener {
secureConnectEnd(call, handshake)
}
}

override fun connectEnd(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?) {
forEachListener {
connectEnd(call, inetSocketAddress, proxy, protocol)
}
}

override fun connectFailed(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?, ioe: IOException) {
forEachListener {
connectFailed(call, inetSocketAddress, proxy, protocol, ioe)
}
}

override fun connectionAcquired(call: Call, connection: Connection) {
forEachListener {
connectionAcquired(call, connection)
}
}

override fun connectionReleased(call: Call, connection: Connection) {
forEachListener {
connectionReleased(call, connection)
}
}

override fun requestHeadersStart(call: Call) {
forEachListener {
requestHeadersStart(call)
}
}

override fun requestHeadersEnd(call: Call, request: Request) {
forEachListener {
requestHeadersEnd(call, request)
}
}

override fun requestBodyStart(call: Call) {
forEachListener {
requestBodyStart(call)
}
}

override fun requestBodyEnd(call: Call, byteCount: Long) {
forEachListener {
requestBodyEnd(call, byteCount)
}
}

override fun requestFailed(call: Call, ioe: IOException) {
forEachListener {
requestFailed(call, ioe)
}
}

override fun responseHeadersStart(call: Call) {
forEachListener {
responseHeadersStart(call)
}
}

override fun responseHeadersEnd(call: Call, response: Response) {
forEachListener {
responseHeadersEnd(call, response)
}
}

override fun responseBodyStart(call: Call) {
forEachListener {
responseBodyStart(call)
}
}

override fun responseBodyEnd(call: Call, byteCount: Long) {
forEachListener {
responseBodyEnd(call, byteCount)
}
}

override fun responseFailed(call: Call, ioe: IOException) {
forEachListener {
responseFailed(call, ioe)
}
}

override fun callEnd(call: Call) {
forEachListener {
callEnd(call)
}
}

override fun callFailed(call: Call, ioe: IOException) {
forEachListener {
callFailed(call, ioe)
}
}

override fun canceled(call: Call) {
forEachListener {
canceled(call)
}
}

override fun satisfactionFailure(call: Call, response: Response) {
forEachListener {
satisfactionFailure(call, response)
}
}

override fun cacheHit(call: Call, response: Response) {
forEachListener {
cacheHit(call, response)
}
}

override fun cacheMiss(call: Call) {
forEachListener {
cacheMiss(call)
}
}

override fun cacheConditionalHit(call: Call, cachedResponse: Response) {
forEachListener {
cacheConditionalHit(call, cachedResponse)
}
}

fun addListener(eventListener: EventListener) {
synchronized(this) {
if (additionalListeners == null) {
additionalListeners = mutableListOf(eventListener)
} else {
additionalListeners?.add(eventListener)
}
}
}
}
Expand Up @@ -31,7 +31,6 @@ import okhttp3.Address
import okhttp3.Call
import okhttp3.Callback
import okhttp3.CertificatePinner
import okhttp3.EventListener
import okhttp3.HttpUrl
import okhttp3.Interceptor
import okhttp3.OkHttpClient
Expand Down Expand Up @@ -67,7 +66,7 @@ class RealCall(
) : Call, Cloneable {
private val connectionPool: RealConnectionPool = client.connectionPool.delegate

internal val eventListener: EventListener = client.eventListenerFactory.create(this)
internal val eventListener: EventListenerList = EventListenerList(client.eventListenerFactory.create(this))
yschimke marked this conversation as resolved.
Show resolved Hide resolved

private val timeout = object : AsyncTimeout() {
override fun timedOut() {
Expand Down
Expand Up @@ -19,6 +19,7 @@ import java.io.IOException
import java.util.concurrent.TimeUnit
import okhttp3.Call
import okhttp3.Connection
import okhttp3.EventListener
import okhttp3.Interceptor
import okhttp3.Request
import okhttp3.Response
Expand Down Expand Up @@ -54,7 +55,7 @@ class RealInterceptorChain(
readTimeoutMillis: Int = this.readTimeoutMillis,
writeTimeoutMillis: Int = this.writeTimeoutMillis
) = RealInterceptorChain(call, interceptors, index, exchange, request, connectTimeoutMillis,
readTimeoutMillis, writeTimeoutMillis)
readTimeoutMillis, writeTimeoutMillis)

override fun connection(): Connection? = exchange?.connection

Expand Down Expand Up @@ -82,6 +83,11 @@ class RealInterceptorChain(
return copy(writeTimeoutMillis = checkDuration("writeTimeout", timeout.toLong(), unit))
}

override fun withEventListener(eventListener: EventListener): Interceptor.Chain {
call.eventListener.addListener(eventListener)
yschimke marked this conversation as resolved.
Show resolved Hide resolved
return this
}

override fun call(): Call = call

override fun request(): Request = request
Expand All @@ -107,7 +113,7 @@ class RealInterceptorChain(

@Suppress("USELESS_ELVIS")
val response = interceptor.intercept(next) ?: throw NullPointerException(
"interceptor $interceptor returned null")
"interceptor $interceptor returned null")

if (exchange != null) {
check(index + 1 >= interceptors.size || next.calls == 1) {
Expand Down
69 changes: 69 additions & 0 deletions okhttp/src/jvmTest/java/okhttp3/InterceptorTest.java
Expand Up @@ -759,6 +759,75 @@ private void interceptorThrowsRuntimeExceptionAsynchronous(boolean network) thro
}
}

@Test
public void interceptorCanAddEventListener() throws Exception {
RecordingEventListener clientEventListener = new RecordingEventListener();
RecordingEventListener applicationEventListener = new RecordingEventListener(false);
RecordingEventListener networkEventListener = new RecordingEventListener(false);

client = client.newBuilder()
.eventListener(clientEventListener)
.addInterceptor(chain ->
chain.withEventListener(applicationEventListener).proceed(chain.request())
)
.addNetworkInterceptor(chain ->
chain.withEventListener(networkEventListener).proceed(chain.request())
)
.build();

server.enqueue(new MockResponse()
.setBody("abc")
.throttleBody(1, 1, TimeUnit.SECONDS));

Request request1 = new Request.Builder().url(server.url("/")).build();
Call call = client.newCall(request1);

call.execute().body().string();

assertThat(clientEventListener.recordedEventTypes()).containsExactly(
"CallStart",
"ProxySelectStart",
"ProxySelectEnd",
"DnsStart",
"DnsEnd",
"ConnectStart",
"ConnectEnd",
"ConnectionAcquired",
"RequestHeadersStart",
"RequestHeadersEnd",
"ResponseHeadersStart",
"ResponseHeadersEnd",
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd");
assertThat(applicationEventListener.recordedEventTypes()).containsExactly(
"ProxySelectStart",
"ProxySelectEnd",
"DnsStart",
"DnsEnd",
"ConnectStart",
"ConnectEnd",
"ConnectionAcquired",
"RequestHeadersStart",
"RequestHeadersEnd",
"ResponseHeadersStart",
"ResponseHeadersEnd",
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd");
assertThat(networkEventListener.recordedEventTypes()).containsExactly(
"RequestHeadersStart",
"RequestHeadersEnd",
"ResponseHeadersStart",
"ResponseHeadersEnd",
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd");
}

@Test public void chainWithWriteTimeout() throws Exception {
Interceptor interceptor1 = chainA -> {
assertThat(chainA.writeTimeoutMillis()).isEqualTo(5000);
Expand Down