From fde88c10c505eaad36241b731beddc837b597c44 Mon Sep 17 00:00:00 2001 From: Erik Little Date: Wed, 27 Jan 2021 13:22:14 -0500 Subject: [PATCH] Support both v2 and v3 --- Source/SocketIO/Client/SocketIOClient.swift | 12 +++ .../Client/SocketIOClientOption.swift | 13 ++++ Source/SocketIO/Engine/SocketEngine.swift | 78 ++++++++++++++++--- .../SocketIO/Engine/SocketEngineClient.swift | 10 ++- .../Engine/SocketEnginePollable.swift | 35 ++++++++- Source/SocketIO/Engine/SocketEngineSpec.swift | 27 ++++++- Source/SocketIO/Manager/SocketManager.swift | 46 +++++++++-- .../SocketIO/Manager/SocketManagerSpec.swift | 3 + Tests/TestSocketIO/SocketEngineTest.swift | 21 ++++- Tests/TestSocketIO/SocketSideEffectTest.swift | 1 + 10 files changed, 219 insertions(+), 27 deletions(-) diff --git a/Source/SocketIO/Client/SocketIOClient.swift b/Source/SocketIO/Client/SocketIOClient.swift index 3047260..160d2e9 100644 --- a/Source/SocketIO/Client/SocketIOClient.swift +++ b/Source/SocketIO/Client/SocketIOClient.swift @@ -134,6 +134,18 @@ open class SocketIOClient: NSObject, SocketIOClientSpec { joinNamespace(withPayload: payload) + switch manager.version { + case .three: + break + case .two where manager.status == .connected && nsp == "/": + // We might not get a connect event for the default nsp, fire immediately + didConnect(toNamespace: nsp, payload: nil) + + return + case _: + break + } + guard timeoutAfter != 0 else { return } manager.handleQueue.asyncAfter(deadline: DispatchTime.now() + timeoutAfter) {[weak self] in diff --git a/Source/SocketIO/Client/SocketIOClientOption.swift b/Source/SocketIO/Client/SocketIOClientOption.swift index c4effb8..1977a92 100644 --- a/Source/SocketIO/Client/SocketIOClientOption.swift +++ b/Source/SocketIO/Client/SocketIOClientOption.swift @@ -25,6 +25,12 @@ import Foundation import Starscream +/// The socket.io version being used. +public enum SocketIOVersion: Int { + case two = 2 + case three = 3 +} + protocol ClientOption : CustomStringConvertible, Equatable { func getSocketIOOptionValue() -> Any } @@ -99,6 +105,9 @@ public enum SocketIOClientOption : ClientOption { /// Sets an NSURLSessionDelegate for the underlying engine. Useful if you need to handle self-signed certs. case sessionDelegate(URLSessionDelegate) + /// The version of socket.io being used. This should match the server version. Default is 3. + case version(SocketIOVersion) + // MARK: Properties /// The description of this option. @@ -148,6 +157,8 @@ public enum SocketIOClientOption : ClientOption { description = "sessionDelegate" case .enableSOCKSProxy: description = "enableSOCKSProxy" + case .version: + description = "version" } return description @@ -199,6 +210,8 @@ public enum SocketIOClientOption : ClientOption { value = delegate case let .enableSOCKSProxy(enable): value = enable + case let.version(versionNum): + value = versionNum } return value diff --git a/Source/SocketIO/Engine/SocketEngine.swift b/Source/SocketIO/Engine/SocketEngine.swift index 7c30d07..6e64ba6 100644 --- a/Source/SocketIO/Engine/SocketEngine.swift +++ b/Source/SocketIO/Engine/SocketEngine.swift @@ -111,6 +111,9 @@ open class SocketEngine: /// The url for WebSockets. public private(set) var urlWebSocket = URL(string: "http://localhost/")! + /// The version of engine.io being used. Default is three. + public private(set) var version: SocketIOVersion = .three + /// If `true`, then the engine is currently in WebSockets mode. @available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets") public private(set) var websocket = false @@ -133,8 +136,14 @@ open class SocketEngine: private var lastCommunication: Date? private var pingInterval: Int? - private var pingTimeout = 0 + private var pingTimeout = 0 { + didSet { + pongsMissedMax = Int(pingTimeout / (pingInterval ?? 25000)) + } + } + private var pongsMissed = 0 + private var pongsMissedMax = 0 private var probeWait = ProbeWaitQueue() private var secure = false private var certPinner: CertificatePinning? @@ -196,8 +205,9 @@ open class SocketEngine: } private func handleBase64(message: String) { + let offset = version.rawValue >= 3 ? 1 : 2 // binary in base64 string - let noPrefix = String(message[message.index(message.startIndex, offsetBy: 1)..= 3 { + checkPings() + } else { + sendPing() + } + if !forceWebsockets { doPoll() } - checkPings() client?.engineDidOpen(reason: "Connect") } private func handlePong(with message: String) { + pongsMissed = 0 + // We should upgrade if message == "3probe" { DefaultSocketLogger.Logger.log("Received probe response, should upgrade to WebSockets", @@ -445,10 +473,14 @@ open class SocketEngine: upgradeTransport() } + + client?.engineDidReceivePong() } private func handlePing(with message: String) { - write("", withType: .pong, withData: []) + if version.rawValue >= 3 { + write("", withType: .pong, withData: []) + } client?.engineDidReceivePing() } @@ -478,7 +510,7 @@ open class SocketEngine: lastCommunication = Date() - client?.parseEngineBinaryData(data) + client?.parseEngineBinaryData(version.rawValue >= 3 ? data : data.subdata(in: 1..= 3 ? "b" : "b4") { return handleBase64(message: message) } - guard let type = SocketEnginePacketType(rawValue: Int(reader.currentCharacter) ?? -1) else { + guard let type = SocketEnginePacketType(rawValue: message.first?.wholeNumberValue ?? -1) else { checkAndHandleEngineError(message) return @@ -536,6 +566,34 @@ open class SocketEngine: waitingForPost = false } + private func sendPing() { + guard connected, let pingInterval = pingInterval else { + print("not connected \(self.connected) or no ping interval \(self.pingInterval ?? -222)") + return + } + + // Server is not responding + if pongsMissed > pongsMissedMax { + closeOutEngine(reason: "Ping timeout") + return + } + + pongsMissed += 1 + write("", withType: .ping, withData: [], completion: nil) + + engineQueue.asyncAfter(deadline: .now() + .milliseconds(pingInterval)) {[weak self, id = self.sid] in + // Make sure not to ping old connections + guard let this = self, this.sid == id else { + print("wrong ping?") + return + } + + this.sendPing() + } + + client?.engineDidSendPing() + } + /// Called when the engine should set/update its configs from a given configuration. /// /// parameter config: The `SocketIOClientConfiguration` that should be used to set/update configs. @@ -570,6 +628,8 @@ open class SocketEngine: self.compress = true case .enableSOCKSProxy: self.enableSOCKSProxy = true + case let .version(num): + version = num default: continue } diff --git a/Source/SocketIO/Engine/SocketEngineClient.swift b/Source/SocketIO/Engine/SocketEngineClient.swift index bd3a3aa..903fa6d 100644 --- a/Source/SocketIO/Engine/SocketEngineClient.swift +++ b/Source/SocketIO/Engine/SocketEngineClient.swift @@ -44,10 +44,16 @@ import Foundation /// - parameter reason: The reason the engine opened. func engineDidOpen(reason: String) - /// Called when the engine receives a ping message. + /// Called when the engine receives a ping message. Only called in socket.io >3. func engineDidReceivePing() - /// Called when the engine sends a pong to the server. + /// Called when the engine receives a pong message. Only called in socket.io 2. + func engineDidReceivePong() + + /// Called when the engine sends a ping to the server. Only called in socket.io 2. + func engineDidSendPing() + + /// Called when the engine sends a pong to the server. Only called in socket.io >3. func engineDidSendPong() /// Called when the engine has a message that must be parsed. diff --git a/Source/SocketIO/Engine/SocketEnginePollable.swift b/Source/SocketIO/Engine/SocketEnginePollable.swift index 9a00a69..a5ee073 100644 --- a/Source/SocketIO/Engine/SocketEnginePollable.swift +++ b/Source/SocketIO/Engine/SocketEnginePollable.swift @@ -79,7 +79,15 @@ extension SocketEnginePollable { postWait.removeAll(keepingCapacity: true) } - let postStr = postWait.lazy.map({ $0.msg }).joined(separator: "\u{1e}") + var postStr = "" + + if version.rawValue >= 3 { + postStr = postWait.lazy.map({ $0.msg }).joined(separator: "\u{1e}") + } else { + for packet in postWait { + postStr += "\(packet.msg.utf16.count):\(packet.msg)" + } + } DefaultSocketLogger.Logger.log("Created POST string: \(postStr)", type: "SocketEnginePolling") @@ -195,10 +203,29 @@ extension SocketEnginePollable { DefaultSocketLogger.Logger.log("Got poll message: \(str)", type: "SocketEnginePolling") - let records = str.components(separatedBy: "\u{1e}") + if version.rawValue >= 3 { + let records = str.components(separatedBy: "\u{1e}") - for record in records { - parseEngineMessage(record) + for record in records { + parseEngineMessage(record) + } + } else { + guard str.count != 1 else { + parseEngineMessage(str) + + return + } + + var reader = SocketStringReader(message: str) + + while reader.hasNext { + if let n = Int(reader.readUntilOccurence(of: ":")) { + parseEngineMessage(reader.read(count: n)) + } else { + parseEngineMessage(str) + break + } + } } } diff --git a/Source/SocketIO/Engine/SocketEngineSpec.swift b/Source/SocketIO/Engine/SocketEngineSpec.swift index cbbb38c..1eecffd 100644 --- a/Source/SocketIO/Engine/SocketEngineSpec.swift +++ b/Source/SocketIO/Engine/SocketEngineSpec.swift @@ -81,6 +81,9 @@ public protocol SocketEngineSpec: class { /// The url for WebSockets. var urlWebSocket: URL { get } + /// The version of engine.io being used. Default is three. + var version: SocketIOVersion { get } + /// If `true`, then the engine is currently in WebSockets mode. @available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets") var websocket: Bool { get } @@ -142,10 +145,23 @@ public protocol SocketEngineSpec: class { } extension SocketEngineSpec { + var engineIOParam: String { + switch version { + case .two: + return "&EIO=3" + case .three: + return "&EIO=4" + } + } + var urlPollingWithSid: URL { var com = URLComponents(url: urlPolling, resolvingAgainstBaseURL: false)! com.percentEncodedQuery = com.percentEncodedQuery! + "&sid=\(sid.urlEncode()!)" + if !com.percentEncodedQuery!.contains("EIO") { + com.percentEncodedQuery = com.percentEncodedQuery! + engineIOParam + } + return com.url! } @@ -153,6 +169,11 @@ extension SocketEngineSpec { var com = URLComponents(url: urlWebSocket, resolvingAgainstBaseURL: false)! com.percentEncodedQuery = com.percentEncodedQuery! + (sid == "" ? "" : "&sid=\(sid.urlEncode()!)") + if !com.percentEncodedQuery!.contains("EIO") { + com.percentEncodedQuery = com.percentEncodedQuery! + engineIOParam + } + + return com.url! } @@ -172,10 +193,12 @@ extension SocketEngineSpec { } func createBinaryDataForSend(using data: Data) -> Either { + let prefixB64 = version.rawValue >= 3 ? "b" : "b4" + if polling { - return .right("b" + data.base64EncodedString(options: Data.Base64EncodingOptions(rawValue: 0))) + return .right(prefixB64 + data.base64EncodedString(options: Data.Base64EncodingOptions(rawValue: 0))) } else { - return .left(data) + return .left(version.rawValue >= 3 ? data : Data([0x4]) + data) } } diff --git a/Source/SocketIO/Manager/SocketManager.swift b/Source/SocketIO/Manager/SocketManager.swift index 4b579b4..c45c5f5 100644 --- a/Source/SocketIO/Manager/SocketManager.swift +++ b/Source/SocketIO/Manager/SocketManager.swift @@ -119,6 +119,8 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat } } + public private(set) var version = SocketIOVersion.three + /// A list of packets that are waiting for binary data. /// /// The way that socket.io works all data should be sent directly after each packet. @@ -214,7 +216,7 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat var payloadStr = "" - if payload != nil, + if version.rawValue >= 3 && payload != nil, let payloadData = try? JSONSerialization.data(withJSONObject: payload!, options: .fragmentsAllowed), let jsonString = String(data: payloadData, encoding: .utf8) { payloadStr = jsonString @@ -349,12 +351,20 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat status = .connected - for (_, socket) in nsps where socket.status == .connecting { + if version.rawValue < 3 { + nsps["/"]?.didConnect(toNamespace: "/", payload: nil) + } + + for (nsp, socket) in nsps where socket.status == .connecting { + if version.rawValue < 3 && nsp == "/" { + continue + } + connectSocket(socket, withPayload: socket.connectPayload) } } - /// Called when the engine receives a pong message. + /// Called when the engine receives a ping message. open func engineDidReceivePing() { handleQueue.async { self._engineDidReceivePing() @@ -366,6 +376,28 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat } /// Called when the sends a ping to the server. + open func engineDidSendPing() { + handleQueue.async { + self._engineDidSendPing() + } + } + + private func _engineDidSendPing() { + emitAll(clientEvent: .ping, data: []) + } + + /// Called when the engine receives a pong message. + open func engineDidReceivePong() { + handleQueue.async { + self._engineDidReceivePong() + } + } + + private func _engineDidReceivePong() { + emitAll(clientEvent: .pong, data: []) + } + + /// Called when the sends a pong to the server. open func engineDidSendPong() { handleQueue.async { self._engineDidSendPong() @@ -508,13 +540,13 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat for option in config { switch option { case let .forceNew(new): - self.forceNew = new + forceNew = new case let .handleQueue(queue): - self.handleQueue = queue + handleQueue = queue case let .reconnects(reconnects): self.reconnects = reconnects case let .reconnectAttempts(attempts): - self.reconnectAttempts = attempts + reconnectAttempts = attempts case let .reconnectWait(wait): reconnectWait = abs(wait) case let .reconnectWaitMax(wait): @@ -525,6 +557,8 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat DefaultSocketLogger.Logger.log = log case let .logger(logger): DefaultSocketLogger.Logger = logger + case let .version(num): + version = num case _: continue } diff --git a/Source/SocketIO/Manager/SocketManagerSpec.swift b/Source/SocketIO/Manager/SocketManagerSpec.swift index 01acc32..87be545 100644 --- a/Source/SocketIO/Manager/SocketManagerSpec.swift +++ b/Source/SocketIO/Manager/SocketManagerSpec.swift @@ -83,6 +83,9 @@ public protocol SocketManagerSpec : AnyObject, SocketEngineClient { /// The status of this manager. var status: SocketIOStatus { get } + /// The version of socket.io in use. + var version: SocketIOVersion { get } + // MARK: Methods /// Connects the underlying transport. diff --git a/Tests/TestSocketIO/SocketEngineTest.swift b/Tests/TestSocketIO/SocketEngineTest.swift index a5806d7..9f44cc8 100644 --- a/Tests/TestSocketIO/SocketEngineTest.swift +++ b/Tests/TestSocketIO/SocketEngineTest.swift @@ -10,6 +10,19 @@ import XCTest @testable import SocketIO class SocketEngineTest: XCTestCase { + func testBasicPollingMessageV3() { + let expect = expectation(description: "Basic polling test v3") + + socket.on("blankTest") {data, ack in + expect.fulfill() + } + + engine.setConfigs([.version(.two)]) + engine.parsePollingMessage("15:42[\"blankTest\"]") + + waitForExpectations(timeout: 3, handler: nil) + } + func testBasicPollingMessage() { let expect = expectation(description: "Basic polling test") socket.on("blankTest") {data, ack in @@ -83,15 +96,15 @@ class SocketEngineTest: XCTestCase { "created": "2016-05-04T18:31:15+0200" ] - XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&created=2016-05-04T18%3A31%3A15%2B0200") - XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&created=2016-05-04T18%3A31%3A15%2B0200") + XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&created=2016-05-04T18%3A31%3A15%2B0200&EIO=4") + XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&created=2016-05-04T18%3A31%3A15%2B0200&EIO=4") engine.connectParams = [ "forbidden": "!*'();:@&=+$,/?%#[]\" {}^|" ] - XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C") - XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C") + XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C&EIO=4") + XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C&EIO=4") } func testBase64Data() { diff --git a/Tests/TestSocketIO/SocketSideEffectTest.swift b/Tests/TestSocketIO/SocketSideEffectTest.swift index 7440924..ecaaee0 100644 --- a/Tests/TestSocketIO/SocketSideEffectTest.swift +++ b/Tests/TestSocketIO/SocketSideEffectTest.swift @@ -485,6 +485,7 @@ class TestEngine: SocketEngineSpec { private(set) var urlWebSocket = URL(string: "http://localhost/")! private(set) var websocket = false private(set) var ws: WebSocket? = nil + private(set) var version = SocketIOVersion.three fileprivate var onConnect: (() -> ())?