diff --git a/okhttp/api/okhttp.api b/okhttp/api/okhttp.api index f5d00d98b77d..91910affffa0 100644 --- a/okhttp/api/okhttp.api +++ b/okhttp/api/okhttp.api @@ -766,6 +766,7 @@ public abstract interface class okhttp3/Interceptor$Chain { public abstract fun readTimeoutMillis ()I public abstract fun request ()Lokhttp3/Request; public abstract fun withConnectTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain; + public abstract fun withEventListener (Lokhttp3/EventListener;)Lokhttp3/Interceptor$Chain; public abstract fun withReadTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain; public abstract fun withWriteTimeout (ILjava/util/concurrent/TimeUnit;)Lokhttp3/Interceptor$Chain; public abstract fun writeTimeoutMillis ()I diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt b/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt index 9c8814d61460..7bc11105a9e2 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt @@ -100,5 +100,12 @@ fun interface Interceptor { fun writeTimeoutMillis(): Int fun withWriteTimeout(timeout: Int, unit: TimeUnit): Chain + + /** + * Add an [EventListener] to the [Call] instance, that will receive all + * subsequent events. The [Chain] is not mutated, since the call is shared + * when copying a chain. + */ + fun withEventListener(eventListener: EventListener): Chain } } diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/EventListenerList.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/EventListenerList.kt new file mode 100644 index 000000000000..fe7e32b704df --- /dev/null +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/EventListenerList.kt @@ -0,0 +1,226 @@ +/* + * Copyright (C) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.connection + +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Proxy +import okhttp3.Call +import okhttp3.Connection +import okhttp3.EventListener +import okhttp3.Handshake +import okhttp3.HttpUrl +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response + +internal class EventListenerList( + var existingListener: EventListener +) : EventListener() { + private var additionalListeners: MutableList? = null + + inline fun forEachListener(fn: EventListener.() -> Unit) { + existingListener.fn() + synchronized(this) { + additionalListeners?.forEach(fn) + } + } + + override fun callStart(call: Call) { + forEachListener { + callStart(call) + } + } + + override fun proxySelectStart(call: Call, url: HttpUrl) { + forEachListener { + proxySelectStart(call, url) + } + } + + override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List) { + forEachListener { + proxySelectEnd(call, url, proxies) + } + } + + override fun dnsStart(call: Call, domainName: String) { + forEachListener { + dnsStart(call, domainName) + } + } + + override fun dnsEnd(call: Call, domainName: String, inetAddressList: List) { + forEachListener { + dnsEnd(call, domainName, inetAddressList) + } + } + + override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) { + forEachListener { + connectStart(call, inetSocketAddress, proxy) + } + } + + override fun secureConnectStart(call: Call) { + forEachListener { + secureConnectStart(call) + } + } + + override fun secureConnectEnd(call: Call, handshake: Handshake?) { + forEachListener { + secureConnectEnd(call, handshake) + } + } + + override fun connectEnd(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?) { + forEachListener { + connectEnd(call, inetSocketAddress, proxy, protocol) + } + } + + override fun connectFailed(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?, ioe: IOException) { + forEachListener { + connectFailed(call, inetSocketAddress, proxy, protocol, ioe) + } + } + + override fun connectionAcquired(call: Call, connection: Connection) { + forEachListener { + connectionAcquired(call, connection) + } + } + + override fun connectionReleased(call: Call, connection: Connection) { + forEachListener { + connectionReleased(call, connection) + } + } + + override fun requestHeadersStart(call: Call) { + forEachListener { + requestHeadersStart(call) + } + } + + override fun requestHeadersEnd(call: Call, request: Request) { + forEachListener { + requestHeadersEnd(call, request) + } + } + + override fun requestBodyStart(call: Call) { + forEachListener { + requestBodyStart(call) + } + } + + override fun requestBodyEnd(call: Call, byteCount: Long) { + forEachListener { + requestBodyEnd(call, byteCount) + } + } + + override fun requestFailed(call: Call, ioe: IOException) { + forEachListener { + requestFailed(call, ioe) + } + } + + override fun responseHeadersStart(call: Call) { + forEachListener { + responseHeadersStart(call) + } + } + + override fun responseHeadersEnd(call: Call, response: Response) { + forEachListener { + responseHeadersEnd(call, response) + } + } + + override fun responseBodyStart(call: Call) { + forEachListener { + responseBodyStart(call) + } + } + + override fun responseBodyEnd(call: Call, byteCount: Long) { + forEachListener { + responseBodyEnd(call, byteCount) + } + } + + override fun responseFailed(call: Call, ioe: IOException) { + forEachListener { + responseFailed(call, ioe) + } + } + + override fun callEnd(call: Call) { + forEachListener { + callEnd(call) + } + } + + override fun callFailed(call: Call, ioe: IOException) { + forEachListener { + callFailed(call, ioe) + } + } + + override fun canceled(call: Call) { + forEachListener { + canceled(call) + } + } + + override fun satisfactionFailure(call: Call, response: Response) { + forEachListener { + satisfactionFailure(call, response) + } + } + + override fun cacheHit(call: Call, response: Response) { + forEachListener { + cacheHit(call, response) + } + } + + override fun cacheMiss(call: Call) { + forEachListener { + cacheMiss(call) + } + } + + override fun cacheConditionalHit(call: Call, cachedResponse: Response) { + forEachListener { + cacheConditionalHit(call, cachedResponse) + } + } + + fun addListener(eventListener: EventListener) { + synchronized(this) { + if (additionalListeners == null) { + additionalListeners = mutableListOf(eventListener) + } else { + additionalListeners?.add(eventListener) + } + } + } +} diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt index 033ee5122348..0e14776bcf48 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt @@ -31,7 +31,6 @@ import okhttp3.Address import okhttp3.Call import okhttp3.Callback import okhttp3.CertificatePinner -import okhttp3.EventListener import okhttp3.HttpUrl import okhttp3.Interceptor import okhttp3.OkHttpClient @@ -67,7 +66,7 @@ class RealCall( ) : Call, Cloneable { private val connectionPool: RealConnectionPool = client.connectionPool.delegate - internal val eventListener: EventListener = client.eventListenerFactory.create(this) + internal val eventListener: EventListenerList = EventListenerList(client.eventListenerFactory.create(this)) private val timeout = object : AsyncTimeout() { override fun timedOut() { diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt index d7fe4e744cd8..f06e13dad179 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt @@ -19,6 +19,7 @@ import java.io.IOException import java.util.concurrent.TimeUnit import okhttp3.Call import okhttp3.Connection +import okhttp3.EventListener import okhttp3.Interceptor import okhttp3.Request import okhttp3.Response @@ -54,7 +55,7 @@ class RealInterceptorChain( readTimeoutMillis: Int = this.readTimeoutMillis, writeTimeoutMillis: Int = this.writeTimeoutMillis ) = RealInterceptorChain(call, interceptors, index, exchange, request, connectTimeoutMillis, - readTimeoutMillis, writeTimeoutMillis) + readTimeoutMillis, writeTimeoutMillis) override fun connection(): Connection? = exchange?.connection @@ -82,6 +83,11 @@ class RealInterceptorChain( return copy(writeTimeoutMillis = checkDuration("writeTimeout", timeout.toLong(), unit)) } + override fun withEventListener(eventListener: EventListener): Interceptor.Chain { + call.eventListener.addListener(eventListener) + return this + } + override fun call(): Call = call override fun request(): Request = request @@ -107,7 +113,7 @@ class RealInterceptorChain( @Suppress("USELESS_ELVIS") val response = interceptor.intercept(next) ?: throw NullPointerException( - "interceptor $interceptor returned null") + "interceptor $interceptor returned null") if (exchange != null) { check(index + 1 >= interceptors.size || next.calls == 1) { diff --git a/okhttp/src/jvmTest/java/okhttp3/InterceptorTest.java b/okhttp/src/jvmTest/java/okhttp3/InterceptorTest.java index 1b0a2fee1ce2..ebd3ef9c02ec 100644 --- a/okhttp/src/jvmTest/java/okhttp3/InterceptorTest.java +++ b/okhttp/src/jvmTest/java/okhttp3/InterceptorTest.java @@ -760,6 +760,77 @@ private void interceptorThrowsRuntimeExceptionAsynchronous(boolean network) thro } } + @Test + public void interceptorCanAddEventListener() throws Exception { + RecordingEventListener clientEventListener = new RecordingEventListener(); + RecordingEventListener applicationEventListener = new RecordingEventListener(false); + RecordingEventListener networkEventListener = new RecordingEventListener(false); + + client = client.newBuilder() + .eventListener(clientEventListener) + .addInterceptor(chain -> + chain.withEventListener(applicationEventListener).proceed(chain.request()) + ) + .addNetworkInterceptor(chain -> + chain.withEventListener(networkEventListener).proceed(chain.request()) + ) + .build(); + + server.enqueue(new MockResponse.Builder() + .body("abc") + .throttleBody(1, 1, TimeUnit.SECONDS) + .build() + ); + + Request request1 = new Request.Builder().url(server.url("/")).build(); + Call call = client.newCall(request1); + + call.execute().body().string(); + + assertThat(clientEventListener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd"); + assertThat(applicationEventListener.recordedEventTypes()).containsExactly( + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd"); + assertThat(networkEventListener.recordedEventTypes()).containsExactly( + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd"); + } + @Test public void chainWithWriteTimeout() throws Exception { Interceptor interceptor1 = chainA -> { assertThat(chainA.writeTimeoutMillis()).isEqualTo(5000); diff --git a/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt b/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt index 6e020d10a847..6726557b09e7 100644 --- a/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt +++ b/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt @@ -1181,6 +1181,7 @@ class KotlinSourceModernTest { override fun withReadTimeout(timeout: Int, unit: TimeUnit): Interceptor.Chain = TODO() override fun writeTimeoutMillis(): Int = TODO() override fun withWriteTimeout(timeout: Int, unit: TimeUnit): Interceptor.Chain = TODO() + override fun withEventListener(eventListener: EventListener): Interceptor.Chain = TODO() } } }