Skip to content

Commit

Permalink
Add timeoutIntervalForRequest and timeoutIntervalForResource (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
atdrendel committed Mar 17, 2021
1 parent 17747e7 commit 00de3b9
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 95 deletions.
2 changes: 2 additions & 0 deletions .swift-version
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
5.3.0

6 changes: 6 additions & 0 deletions .swiftformat
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--funcattributes prev-line
--minversion 0.47.2
--maxwidth 100
--typeattributes prev-line
--wraparguments before-first
--wrapcollections before-first
23 changes: 17 additions & 6 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@ let package = Package(
products: [
.library(
name: "WebSocket",
targets: ["WebSocket"]),
targets: ["WebSocket"]
),
],
dependencies: [
.package(name: "Synchronized", url: "https://github.com/shareup/synchronized.git", from: "2.1.0"),
.package(name: "WebSocketProtocol", url: "https://github.com/shareup/websocket-protocol.git", from: "2.2.0"),
.package(
name: "Synchronized",
url: "https://github.com/shareup/synchronized.git",
from: "2.1.0"
),
.package(
name: "WebSocketProtocol",
url: "https://github.com/shareup/websocket-protocol.git",
from: "2.2.0"
),
.package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
],
targets: [
.target(
name: "WebSocket",
dependencies: ["Synchronized", "WebSocketProtocol"]),
dependencies: ["Synchronized", "WebSocketProtocol"]
),
.testTarget(
name: "WebSocketTests",
dependencies: [
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOWebSocket", package: "swift-nio"),
"WebSocket"
]),
"WebSocket",
]
),
]
)
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import Foundation
import WebSocketProtocol

extension URLSessionWebSocketTask.CloseCode {
public init?(_ closeCode: WebSocketCloseCode) {
public extension URLSessionWebSocketTask.CloseCode {
init?(_ closeCode: WebSocketCloseCode) {
self.init(rawValue: closeCode.rawValue)
}
}
73 changes: 49 additions & 24 deletions Sources/WebSocket/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,35 @@ public final class WebSocket: WebSocketProtocol {
return true
} }

private let lock: RecursiveLock = RecursiveLock()
private func sync<T>(_ block: () throws -> T) rethrows -> T { return try lock.locked(block) }
private let lock = RecursiveLock()
private func sync<T>(_ block: () throws -> T) rethrows -> T { try lock.locked(block) }

private let url: URL

private let timeoutIntervalForRequest: TimeInterval
private let timeoutIntervalForResource: TimeInterval

private var state: State = .unopened
private let subject = PassthroughSubject<Output, Failure>()

private let subjectQueue: DispatchQueue

public convenience init(url: URL) {
self.init(url: url, publisherQueue: nil)
self.init(url: url, publisherQueue: DispatchQueue.global())
}

public init(url: URL, publisherQueue: DispatchQueue?) {
public init(
url: URL,
timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds
timeoutIntervalForResource: TimeInterval = 604_800, // 7 days
publisherQueue: DispatchQueue = DispatchQueue.global()
) {
self.url = url
self.subjectQueue = DispatchQueue(
self.timeoutIntervalForRequest = timeoutIntervalForRequest
self.timeoutIntervalForResource = timeoutIntervalForResource
subjectQueue = DispatchQueue(
label: "app.shareup.websocket.subjectqueue",
attributes: [],
qos: .default,
autoreleaseFrequency: .workItem,
target: publisherQueue
)
Expand All @@ -91,23 +101,30 @@ public final class WebSocket: WebSocketProtocol {
state.debugDescription
)

switch (state) {
switch state {
case .closed, .unopened:
let delegate = WebSocketDelegate(
onOpen: onOpen,
onClose: onClose,
onCompletion: onCompletion
)

let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = timeoutIntervalForRequest
config.timeoutIntervalForResource = timeoutIntervalForResource

let session = URLSession(
configuration: .default,
configuration: config,
delegate: delegate,
delegateQueue: nil
)

let task = session.webSocketTask(with: url)
task.maximumMessageSize = maximumMessageSize
state = .connecting(session, task, delegate)
task.resume()
receiveFromWebSocket()

default:
break
}
Expand All @@ -121,7 +138,7 @@ public final class WebSocket: WebSocketProtocol {
}

private func receiveFromWebSocket() {
let task: URLSessionWebSocketTask? = self.sync {
let task: URLSessionWebSocketTask? = sync {
let webSocketTask = self.state.webSocketSessionAndTask?.1
guard let task = webSocketTask, case .running = task.state else { return nil }
return task
Expand Down Expand Up @@ -169,7 +186,7 @@ public final class WebSocket: WebSocketProtocol {
completionHandler: @escaping (Error?) -> Void
) {
let task: URLSessionWebSocketTask? = sync {
guard case .open(_, let task, _) = state, task.state == .running
guard case let .open(_, task, _) = state, task.state == .running
else {
os_log(
"send message in incorrect task state: message=%s taskstate=%{public}@",
Expand All @@ -187,8 +204,8 @@ public final class WebSocket: WebSocketProtocol {
task?.send(message, completionHandler: completionHandler)
}

public func close(_ closeCode: WebSocketCloseCode) {
let task: URLSessionWebSocketTask? = self.sync {
public func close(_ closeCode: WebSocketCloseCode) {
let task: URLSessionWebSocketTask? = sync {
os_log(
"close: oldstate=%{public}@ code=%lld",
log: .webSocket,
Expand All @@ -209,16 +226,21 @@ public final class WebSocket: WebSocketProtocol {
}

private typealias OnOpenHandler = (URLSession, URLSessionWebSocketTask, String?) -> Void
private typealias OnCloseHandler = (URLSession, URLSessionWebSocketTask, URLSessionWebSocketTask.CloseCode, Data?) -> Void
private typealias OnCloseHandler = (
URLSession,
URLSessionWebSocketTask,
URLSessionWebSocketTask.CloseCode,
Data?
) -> Void
private typealias OnCompletionHandler = (URLSession, URLSessionTask, Error?) -> Void

private let normalCloseCodes: [URLSessionWebSocketTask.CloseCode] = [.goingAway, .normalClosure]

// MARK: onOpen and onClose

private extension WebSocket {
private extension WebSocket {
var onOpen: OnOpenHandler {
return { [weak self] (webSocketSession, webSocketTask, `protocol`) in
{ [weak self] webSocketSession, webSocketTask, _ in
guard let self = self else { return }

self.sync {
Expand Down Expand Up @@ -255,7 +277,7 @@ private extension WebSocket {
}

var onClose: OnCloseHandler {
return { [weak self] (webSocketSession, webSocketTask, closeCode, reason) in
{ [weak self] _, _, closeCode, reason in
guard let self = self else { return }

self.sync {
Expand Down Expand Up @@ -284,10 +306,10 @@ private extension WebSocket {
}

var onCompletion: OnCompletionHandler {
return { [weak self] (webSocketSession, webSocketTask, error) in
{ [weak self] webSocketSession, _, error in
defer { webSocketSession.invalidateAndCancel() }
guard let self = self else { return }

os_log("onCompletion", log: .webSocket, type: .debug)

// "The only errors your delegate receives through the error parameter
Expand Down Expand Up @@ -338,20 +360,23 @@ private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate {

func urlSession(_ webSocketSession: URLSession,
webSocketTask: URLSessionWebSocketTask,
didOpenWithProtocol protocol: String?) {
self.onOpen(webSocketSession, webSocketTask, `protocol`)
didOpenWithProtocol protocol: String?)
{
onOpen(webSocketSession, webSocketTask, `protocol`)
}

func urlSession(_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
reason: Data?) {
self.onClose(session, webSocketTask, closeCode, reason)
reason: Data?)
{
onClose(session, webSocketTask, closeCode, reason)
}

func urlSession(_ session: URLSession,
task: URLSessionTask,
didCompleteWithError error: Error?) {
self.onCompletion(session, task, error)
didCompleteWithError error: Error?)
{
onCompletion(session, task, error)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import WebSocketProtocol
extension WebSocketMessage {
init(_ message: URLSessionWebSocketTask.Message) {
switch message {
case .data(let data):
case let .data(data):
self = .binary(data)
case .string(let string):
case let .string(string):
self = .text(string)
@unknown default:
assertionFailure("Unknown WebSocket Message type")
Expand Down
52 changes: 30 additions & 22 deletions Tests/WebSocketTests/Server/WebSocketServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ final class WebSocketServer {

init(port: UInt16, replyProvider: ReplyType) {
self.port = port
self.replyType = replyProvider
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
replyType = replyProvider
eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
}

func listen() {
Expand All @@ -36,11 +36,9 @@ final class WebSocketServer {
throw NIO.ChannelError.unknownLocalAddress
}
print("WebSocketServer running on \(localAddress)")
}
catch let error as NIO.IOError {
} catch let error as NIO.IOError {
print("Failed to start server: \(error.errnoCode) '\(error.localizedDescription)'")
}
catch {
} catch {
print("Failed to start server: \(String(describing: error))")
}
}
Expand All @@ -50,25 +48,30 @@ final class WebSocketServer {
catch { print("Failed to wait on server: \(error)") }
}

private func shouldUpgrade(channel: Channel, head: HTTPRequestHead) -> EventLoopFuture<HTTPHeaders?> {
private func shouldUpgrade(channel _: Channel,
head: HTTPRequestHead) -> EventLoopFuture<HTTPHeaders?>
{
let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil
return eventLoopGroup.next().makeSucceededFuture(headers)
}

private func upgradePipelineHandler(channel: Channel, head: HTTPRequestHead) -> NIO.EventLoopFuture<Void> {
return head.uri.starts(with: "/socket") ?
channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : channel.closeFuture
private func upgradePipelineHandler(channel: Channel, head: HTTPRequestHead) -> NIO
.EventLoopFuture<Void>
{
head.uri.starts(with: "/socket") ?
channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : channel
.closeFuture
}

private var replyProvider: (String) -> String? {
return { [weak self] (input: String) -> String? in
{ [weak self] (input: String) -> String? in
guard let self = self else { return nil }
switch self.replyType {
case .echo:
return input
case .reply(let iterator):
case let .reply(iterator):
return iterator()
case .matchReply(let matcher):
case let .matchReply(matcher):
return matcher(input)
}
}
Expand Down Expand Up @@ -104,7 +107,7 @@ final class WebSocketServer {
}

private class WebSocketHandler: ChannelInboundHandler {
typealias InboundIn = WebSocketFrame
typealias InboundIn = WebSocketFrame
typealias OutboundOut = WebSocketFrame

private let replyProvider: (String) -> String?
Expand All @@ -119,9 +122,9 @@ private class WebSocketHandler: ChannelInboundHandler {

switch frame.opcode {
case .connectionClose:
self.onClose(context: context, frame: frame)
onClose(context: context, frame: frame)
case .ping:
self.onPing(context: context, frame: frame)
onPing(context: context, frame: frame)
case .text:
var data = frame.unmaskedData
let text = data.readString(length: data.readableBytes) ?? ""
Expand All @@ -142,7 +145,7 @@ private class WebSocketHandler: ChannelInboundHandler {
if let text = String(data: binary, encoding: .utf8) {
onText(context: context, text: text)
} else {
throw NIO.IOError.init(errnoCode: EBADMSG, reason: "Invalid message")
throw NIO.IOError(errnoCode: EBADMSG, reason: "Invalid message")
}
} catch {
onError(context: context)
Expand All @@ -168,7 +171,7 @@ private class WebSocketHandler: ChannelInboundHandler {
}

let pong = WebSocketFrame(fin: true, opcode: .pong, data: frameData)
context.write(self.wrapOutboundOut(pong), promise: nil)
context.write(wrapOutboundOut(pong), promise: nil)
}

private func onClose(context: ChannelHandlerContext, frame: WebSocketFrame) {
Expand All @@ -178,9 +181,14 @@ private class WebSocketHandler: ChannelInboundHandler {
} else {
// The close came from the client.
var data = frame.unmaskedData
let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator.buffer(capacity: 0)
let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode)
_ = context.write(self.wrapOutboundOut(closeFrame)).map { () in
let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator
.buffer(capacity: 0)
let closeFrame = WebSocketFrame(
fin: true,
opcode: .connectionClose,
data: closeDataCode
)
_ = context.write(wrapOutboundOut(closeFrame)).map { () in
context.close(promise: nil)
}
}
Expand All @@ -190,7 +198,7 @@ private class WebSocketHandler: ChannelInboundHandler {
var data = context.channel.allocator.buffer(capacity: 2)
data.write(webSocketErrorCode: .protocolError)
let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data)
context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result<Void, Error>) in
context.write(wrapOutboundOut(frame)).whenComplete { (_: Result<Void, Error>) in
context.close(mode: .output, promise: nil)
}
awaitingClose = true
Expand Down
Loading

0 comments on commit 00de3b9

Please sign in to comment.