Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment with locks #8390

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -45,6 +45,7 @@ import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManager
import javax.net.ssl.X509TrustManager
import kotlin.time.Duration.Companion.milliseconds
import mockwebserver3.SocketPolicy.DisconnectAfterRequest
import mockwebserver3.SocketPolicy.DisconnectAtEnd
import mockwebserver3.SocketPolicy.DisconnectAtStart
Expand Down Expand Up @@ -374,6 +375,12 @@ class MockWebServer : Closeable {
@Throws(Exception::class)
private fun acceptConnections() {
while (true) {
val socketPolicy = dispatcher.peek().socketPolicy

if (socketPolicy is SocketPolicy.DelayAccept) {
Thread.sleep(100.milliseconds.inWholeMilliseconds)
}

val socket: Socket
try {
socket = serverSocket!!.accept()
Expand All @@ -382,7 +389,6 @@ class MockWebServer : Closeable {
return
}

val socketPolicy = dispatcher.peek().socketPolicy
if (socketPolicy === DisconnectAtStart) {
dispatchBookkeepingRequest(0, socket)
socket.close()
Expand Down
6 changes: 6 additions & 0 deletions mockwebserver/src/main/kotlin/mockwebserver3/SocketPolicy.kt
Expand Up @@ -16,6 +16,7 @@

package mockwebserver3

import kotlin.time.Duration
import okhttp3.ExperimentalOkHttpApi

/**
Expand Down Expand Up @@ -59,6 +60,11 @@ sealed interface SocketPolicy {
*/
object DisconnectAtStart : SocketPolicy

/**
* Delay before accepting on the ServerSocket.
*/
class DelayAccept(val delay: Duration) : SocketPolicy

/**
* Close connection after reading the request but before writing the response. Use this to
* simulate late connection pool failures.
Expand Down
Expand Up @@ -25,9 +25,9 @@ import java.util.logging.Level
import java.util.logging.LogManager
import java.util.logging.LogRecord
import java.util.logging.Logger
import kotlin.concurrent.withLock
import okhttp3.internal.buildConnectionPool
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.connection.Locks.withLock
import okhttp3.internal.connection.RealConnectionPool
import okhttp3.internal.http2.Http2
import okhttp3.internal.taskRunnerInternal
Expand Down Expand Up @@ -234,7 +234,7 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
// a test timeout failure.
val waitTime = (entryTime + 1_000_000_000L - System.nanoTime())
if (!queue.idleLatch().await(waitTime, TimeUnit.NANOSECONDS)) {
TaskRunner.INSTANCE.lock.withLock {
TaskRunner.INSTANCE.withLock {
TaskRunner.INSTANCE.cancelAll()
}
fail<Unit>("Queue still active after 1000 ms")
Expand Down
Expand Up @@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")

package okhttp3.internal.concurrent

import assertk.assertThat
Expand All @@ -23,9 +25,9 @@ import java.util.concurrent.BlockingQueue
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.logging.Logger
import kotlin.concurrent.withLock
import okhttp3.OkHttpClient
import okhttp3.TestUtil.threadFactory
import okhttp3.internal.connection.Locks.withLock

/**
* Runs a [TaskRunner] in a controlled environment so that everything is sequential and
Expand Down Expand Up @@ -166,7 +168,7 @@ class TaskFaker : Closeable {
fun advanceUntil(newTime: Long) {
taskRunner.assertThreadDoesntHoldLock()

taskRunner.lock.withLock {
taskRunner.withLock {
check(currentTask == TestThreadSerialTask)
nanoTime = newTime
yieldUntil(ResumePriority.AfterOtherTasks)
Expand All @@ -177,7 +179,7 @@ class TaskFaker : Closeable {
fun assertNoMoreTasks() {
taskRunner.assertThreadDoesntHoldLock()

taskRunner.lock.withLock {
taskRunner.withLock {
assertThat(activeThreads).isEqualTo(0)
}
}
Expand Down Expand Up @@ -207,7 +209,7 @@ class TaskFaker : Closeable {
fun runNextTask() {
taskRunner.assertThreadDoesntHoldLock()

taskRunner.lock.withLock {
taskRunner.withLock {
val contextSwitchCountBefore = contextSwitchCount
yieldUntil(ResumePriority.BeforeOtherTasks) {
contextSwitchCount > contextSwitchCountBefore
Expand All @@ -217,7 +219,7 @@ class TaskFaker : Closeable {

/** Sleep until [durationNanos] elapses. For use by the task threads. */
fun sleep(durationNanos: Long) {
taskRunner.lock.withLock {
taskRunner.withLock {
val sleepUntil = nanoTime + durationNanos
yieldUntil { nanoTime >= sleepUntil }
}
Expand All @@ -229,7 +231,7 @@ class TaskFaker : Closeable {
*/
fun yield() {
taskRunner.assertThreadDoesntHoldLock()
taskRunner.lock.withLock {
taskRunner.withLock {
yieldUntil()
}
}
Expand Down Expand Up @@ -328,7 +330,7 @@ class TaskFaker : Closeable {
runnable.run()
require(currentTask == this) { "unexpected current task: $currentTask" }
} finally {
taskRunner.lock.withLock {
taskRunner.withLock {
activeThreads--
startNextTask()
}
Expand All @@ -354,7 +356,7 @@ class TaskFaker : Closeable {
timeout: Long,
unit: TimeUnit,
): T? {
taskRunner.lock.withLock {
taskRunner.withLock {
val waitUntil = nanoTime + unit.toNanos(timeout)
while (true) {
val result = poll()
Expand All @@ -367,7 +369,7 @@ class TaskFaker : Closeable {
}

override fun put(element: T) {
taskRunner.lock.withLock {
taskRunner.withLock {
delegate.put(element)
editCount++
}
Expand Down
14 changes: 7 additions & 7 deletions okhttp/src/main/kotlin/okhttp3/internal/concurrent/TaskQueue.kt
Expand Up @@ -18,8 +18,8 @@ package okhttp3.internal.concurrent
import java.util.concurrent.CountDownLatch
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import okhttp3.internal.assertNotHeld
import okhttp3.internal.connection.Locks.withLock
import okhttp3.internal.okHttpName

/**
Expand All @@ -32,7 +32,7 @@ class TaskQueue internal constructor(
internal val taskRunner: TaskRunner,
internal val name: String,
) {
val lock: ReentrantLock = ReentrantLock()
internal val lock: ReentrantLock = ReentrantLock()

internal var shutdown = false

Expand All @@ -50,7 +50,7 @@ class TaskQueue internal constructor(
* currently-executing task unless it is also scheduled for future execution.
*/
val scheduledTasks: List<Task>
get() = taskRunner.lock.withLock { futureTasks.toList() }
get() = taskRunner.withLock { futureTasks.toList() }

/**
* Schedules [task] for execution in [delayNanos]. A task may only have one future execution
Expand All @@ -66,7 +66,7 @@ class TaskQueue internal constructor(
task: Task,
delayNanos: Long = 0L,
) {
taskRunner.lock.withLock {
taskRunner.withLock {
if (shutdown) {
if (task.cancelable) {
taskRunner.logger.taskLog(task, this) { "schedule canceled (queue is shutdown)" }
Expand Down Expand Up @@ -126,7 +126,7 @@ class TaskQueue internal constructor(

/** Returns a latch that reaches 0 when the queue is next idle. */
fun idleLatch(): CountDownLatch {
taskRunner.lock.withLock {
taskRunner.withLock {
// If the queue is already idle, that's easy.
if (activeTask == null && futureTasks.isEmpty()) {
return CountDownLatch(0)
Expand Down Expand Up @@ -208,7 +208,7 @@ class TaskQueue internal constructor(
fun cancelAll() {
lock.assertNotHeld()

taskRunner.lock.withLock {
taskRunner.withLock {
if (cancelAllAndDecide()) {
taskRunner.kickCoordinator(this)
}
Expand All @@ -218,7 +218,7 @@ class TaskQueue internal constructor(
fun shutdown() {
lock.assertNotHeld()

taskRunner.lock.withLock {
taskRunner.withLock {
shutdown = true
if (cancelAllAndDecide()) {
taskRunner.kickCoordinator(this)
Expand Down
35 changes: 25 additions & 10 deletions okhttp/src/main/kotlin/okhttp3/internal/concurrent/TaskRunner.kt
Expand Up @@ -23,10 +23,15 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import java.util.logging.Logger
import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.microseconds
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.measureTime
import okhttp3.internal.addIfAbsent
import okhttp3.internal.assertHeld
import okhttp3.internal.concurrent.TaskRunner.Companion.INSTANCE
import okhttp3.internal.connection.Locks
import okhttp3.internal.connection.Locks.newLockCondition
import okhttp3.internal.connection.Locks.withLock
import okhttp3.internal.okHttpName
import okhttp3.internal.threadFactory

Expand All @@ -45,8 +50,8 @@ class TaskRunner(
val backend: Backend,
internal val logger: Logger = TaskRunner.logger,
) {
val lock: ReentrantLock = ReentrantLock()
val condition: Condition = lock.newCondition()
internal val lock: ReentrantLock = ReentrantLock()
val condition: Condition = lock.newLockCondition()

private var nextQueueName = 10000
private var coordinatorWaiting = false
Expand All @@ -63,7 +68,7 @@ class TaskRunner(
override fun run() {
while (true) {
val task =
this@TaskRunner.lock.withLock {
[email protected] {
awaitTaskToRun()
} ?: return

Expand All @@ -75,7 +80,7 @@ class TaskRunner(
} finally {
// If the task is crashing start another thread to service the queues.
if (!completedNormally) {
lock.withLock {
this@TaskRunner.withLock {
backend.execute(this@TaskRunner, this)
}
}
Expand Down Expand Up @@ -123,7 +128,7 @@ class TaskRunner(
try {
delayNanos = task.runOnce()
} finally {
lock.withLock {
this.withLock {
afterRun(task, delayNanos)
}
currentThread.name = oldName
Expand Down Expand Up @@ -239,7 +244,7 @@ class TaskRunner(
}

fun newQueue(): TaskQueue {
val name = lock.withLock { nextQueueName++ }
val name = this.withLock { nextQueueName++ }
return TaskQueue(this, "Q$name")
}

Expand All @@ -248,7 +253,7 @@ class TaskRunner(
* necessarily track queues that have no tasks scheduled.
*/
fun activeQueues(): List<TaskQueue> {
lock.withLock {
this.withLock {
return busyQueues + readyQueues
}
}
Expand Down Expand Up @@ -295,7 +300,7 @@ class TaskRunner(
// keepAliveTime:
60L,
TimeUnit.SECONDS,
SynchronousQueue(),
SynchronousQueue(false),
threadFactory,
)

Expand Down Expand Up @@ -327,7 +332,13 @@ class TaskRunner(
taskRunner: TaskRunner,
runnable: Runnable,
) {
executor.execute(runnable)
val time = measureTime {
executor.execute(runnable)
}

if (time > 500.microseconds) {
println("executor.execute " + time)
}
}

fun shutdown() {
Expand All @@ -340,5 +351,9 @@ class TaskRunner(

@JvmField
val INSTANCE = TaskRunner(RealBackend(threadFactory("$okHttpName TaskRunner", daemon = true)))

init {
Locks.lockToWatch = INSTANCE.lock
}
}
}
Expand Up @@ -26,7 +26,6 @@ import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import kotlin.concurrent.withLock
import okhttp3.CertificatePinner
import okhttp3.ConnectionSpec
import okhttp3.Handshake
Expand Down