Support both v2 and v3

This commit is contained in:
Erik Little 2021-01-27 13:22:14 -05:00
parent 21290f5752
commit fde88c10c5
No known key found for this signature in database
GPG Key ID: 62F837E56F4E9320
10 changed files with 219 additions and 27 deletions

View File

@ -134,6 +134,18 @@ open class SocketIOClient: NSObject, SocketIOClientSpec {
joinNamespace(withPayload: payload) joinNamespace(withPayload: payload)
switch manager.version {
case .three:
break
case .two where manager.status == .connected && nsp == "/":
// We might not get a connect event for the default nsp, fire immediately
didConnect(toNamespace: nsp, payload: nil)
return
case _:
break
}
guard timeoutAfter != 0 else { return } guard timeoutAfter != 0 else { return }
manager.handleQueue.asyncAfter(deadline: DispatchTime.now() + timeoutAfter) {[weak self] in manager.handleQueue.asyncAfter(deadline: DispatchTime.now() + timeoutAfter) {[weak self] in

View File

@ -25,6 +25,12 @@
import Foundation import Foundation
import Starscream import Starscream
/// The socket.io version being used.
public enum SocketIOVersion: Int {
case two = 2
case three = 3
}
protocol ClientOption : CustomStringConvertible, Equatable { protocol ClientOption : CustomStringConvertible, Equatable {
func getSocketIOOptionValue() -> Any func getSocketIOOptionValue() -> Any
} }
@ -99,6 +105,9 @@ public enum SocketIOClientOption : ClientOption {
/// Sets an NSURLSessionDelegate for the underlying engine. Useful if you need to handle self-signed certs. /// Sets an NSURLSessionDelegate for the underlying engine. Useful if you need to handle self-signed certs.
case sessionDelegate(URLSessionDelegate) case sessionDelegate(URLSessionDelegate)
/// The version of socket.io being used. This should match the server version. Default is 3.
case version(SocketIOVersion)
// MARK: Properties // MARK: Properties
/// The description of this option. /// The description of this option.
@ -148,6 +157,8 @@ public enum SocketIOClientOption : ClientOption {
description = "sessionDelegate" description = "sessionDelegate"
case .enableSOCKSProxy: case .enableSOCKSProxy:
description = "enableSOCKSProxy" description = "enableSOCKSProxy"
case .version:
description = "version"
} }
return description return description
@ -199,6 +210,8 @@ public enum SocketIOClientOption : ClientOption {
value = delegate value = delegate
case let .enableSOCKSProxy(enable): case let .enableSOCKSProxy(enable):
value = enable value = enable
case let.version(versionNum):
value = versionNum
} }
return value return value

View File

@ -111,6 +111,9 @@ open class SocketEngine:
/// The url for WebSockets. /// The url for WebSockets.
public private(set) var urlWebSocket = URL(string: "http://localhost/")! public private(set) var urlWebSocket = URL(string: "http://localhost/")!
/// The version of engine.io being used. Default is three.
public private(set) var version: SocketIOVersion = .three
/// If `true`, then the engine is currently in WebSockets mode. /// If `true`, then the engine is currently in WebSockets mode.
@available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets") @available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets")
public private(set) var websocket = false public private(set) var websocket = false
@ -133,8 +136,14 @@ open class SocketEngine:
private var lastCommunication: Date? private var lastCommunication: Date?
private var pingInterval: Int? private var pingInterval: Int?
private var pingTimeout = 0 private var pingTimeout = 0 {
didSet {
pongsMissedMax = Int(pingTimeout / (pingInterval ?? 25000))
}
}
private var pongsMissed = 0
private var pongsMissedMax = 0
private var probeWait = ProbeWaitQueue() private var probeWait = ProbeWaitQueue()
private var secure = false private var secure = false
private var certPinner: CertificatePinning? private var certPinner: CertificatePinning?
@ -196,8 +205,9 @@ open class SocketEngine:
} }
private func handleBase64(message: String) { private func handleBase64(message: String) {
let offset = version.rawValue >= 3 ? 1 : 2
// binary in base64 string // binary in base64 string
let noPrefix = String(message[message.index(message.startIndex, offsetBy: 1)..<message.endIndex]) let noPrefix = String(message[message.index(message.startIndex, offsetBy: offset)..<message.endIndex])
if let data = Data(base64Encoded: noPrefix, options: .ignoreUnknownCharacters) { if let data = Data(base64Encoded: noPrefix, options: .ignoreUnknownCharacters) {
client?.parseEngineBinaryData(data) client?.parseEngineBinaryData(data)
@ -278,6 +288,14 @@ open class SocketEngine:
urlWebSocket.percentEncodedQuery = "transport=websocket" + queryString urlWebSocket.percentEncodedQuery = "transport=websocket" + queryString
urlPolling.percentEncodedQuery = "transport=polling&b64=1" + queryString urlPolling.percentEncodedQuery = "transport=polling&b64=1" + queryString
if !urlWebSocket.percentEncodedQuery!.contains("EIO") {
urlWebSocket.percentEncodedQuery = urlWebSocket.percentEncodedQuery! + engineIOParam
}
if !urlPolling.percentEncodedQuery!.contains("EIO") {
urlPolling.percentEncodedQuery = urlPolling.percentEncodedQuery! + engineIOParam
}
return (urlPolling.url!, urlWebSocket.url!) return (urlPolling.url!, urlWebSocket.url!)
} }
@ -289,6 +307,8 @@ open class SocketEngine:
includingCookies: session?.configuration.httpCookieStorage?.cookies(for: urlPollingWithSid) includingCookies: session?.configuration.httpCookieStorage?.cookies(for: urlPollingWithSid)
) )
print("ws req: \(req)")
ws = WebSocket(request: req, certPinner: certPinner, compressionHandler: compress ? WSCompression() : nil) ws = WebSocket(request: req, certPinner: certPinner, compressionHandler: compress ? WSCompression() : nil)
ws?.callbackQueue = engineQueue ws?.callbackQueue = engineQueue
ws?.delegate = self ws?.delegate = self
@ -413,6 +433,7 @@ open class SocketEngine:
self.sid = sid self.sid = sid
connected = true connected = true
pongsMissed = 0
if let upgrades = json["upgrades"] as? [String] { if let upgrades = json["upgrades"] as? [String] {
upgradeWs = upgrades.contains("websocket") upgradeWs = upgrades.contains("websocket")
@ -429,15 +450,22 @@ open class SocketEngine:
createWebSocketAndConnect() createWebSocketAndConnect()
} }
if version.rawValue >= 3 {
checkPings()
} else {
sendPing()
}
if !forceWebsockets { if !forceWebsockets {
doPoll() doPoll()
} }
checkPings()
client?.engineDidOpen(reason: "Connect") client?.engineDidOpen(reason: "Connect")
} }
private func handlePong(with message: String) { private func handlePong(with message: String) {
pongsMissed = 0
// We should upgrade // We should upgrade
if message == "3probe" { if message == "3probe" {
DefaultSocketLogger.Logger.log("Received probe response, should upgrade to WebSockets", DefaultSocketLogger.Logger.log("Received probe response, should upgrade to WebSockets",
@ -445,10 +473,14 @@ open class SocketEngine:
upgradeTransport() upgradeTransport()
} }
client?.engineDidReceivePong()
} }
private func handlePing(with message: String) { private func handlePing(with message: String) {
write("", withType: .pong, withData: []) if version.rawValue >= 3 {
write("", withType: .pong, withData: [])
}
client?.engineDidReceivePing() client?.engineDidReceivePing()
} }
@ -478,7 +510,7 @@ open class SocketEngine:
lastCommunication = Date() lastCommunication = Date()
client?.parseEngineBinaryData(data) client?.parseEngineBinaryData(version.rawValue >= 3 ? data : data.subdata(in: 1..<data.endIndex))
} }
/// Parses a raw engine.io packet. /// Parses a raw engine.io packet.
@ -489,13 +521,11 @@ open class SocketEngine:
DefaultSocketLogger.Logger.log("Got message: \(message)", type: SocketEngine.logType) DefaultSocketLogger.Logger.log("Got message: \(message)", type: SocketEngine.logType)
let reader = SocketStringReader(message: message) if message.hasPrefix(version.rawValue >= 3 ? "b" : "b4") {
if message.hasPrefix("b") {
return handleBase64(message: message) return handleBase64(message: message)
} }
guard let type = SocketEnginePacketType(rawValue: Int(reader.currentCharacter) ?? -1) else { guard let type = SocketEnginePacketType(rawValue: message.first?.wholeNumberValue ?? -1) else {
checkAndHandleEngineError(message) checkAndHandleEngineError(message)
return return
@ -536,6 +566,34 @@ open class SocketEngine:
waitingForPost = false waitingForPost = false
} }
private func sendPing() {
guard connected, let pingInterval = pingInterval else {
print("not connected \(self.connected) or no ping interval \(self.pingInterval ?? -222)")
return
}
// Server is not responding
if pongsMissed > pongsMissedMax {
closeOutEngine(reason: "Ping timeout")
return
}
pongsMissed += 1
write("", withType: .ping, withData: [], completion: nil)
engineQueue.asyncAfter(deadline: .now() + .milliseconds(pingInterval)) {[weak self, id = self.sid] in
// Make sure not to ping old connections
guard let this = self, this.sid == id else {
print("wrong ping?")
return
}
this.sendPing()
}
client?.engineDidSendPing()
}
/// Called when the engine should set/update its configs from a given configuration. /// Called when the engine should set/update its configs from a given configuration.
/// ///
/// parameter config: The `SocketIOClientConfiguration` that should be used to set/update configs. /// parameter config: The `SocketIOClientConfiguration` that should be used to set/update configs.
@ -570,6 +628,8 @@ open class SocketEngine:
self.compress = true self.compress = true
case .enableSOCKSProxy: case .enableSOCKSProxy:
self.enableSOCKSProxy = true self.enableSOCKSProxy = true
case let .version(num):
version = num
default: default:
continue continue
} }

View File

@ -44,10 +44,16 @@ import Foundation
/// - parameter reason: The reason the engine opened. /// - parameter reason: The reason the engine opened.
func engineDidOpen(reason: String) func engineDidOpen(reason: String)
/// Called when the engine receives a ping message. /// Called when the engine receives a ping message. Only called in socket.io >3.
func engineDidReceivePing() func engineDidReceivePing()
/// Called when the engine sends a pong to the server. /// Called when the engine receives a pong message. Only called in socket.io 2.
func engineDidReceivePong()
/// Called when the engine sends a ping to the server. Only called in socket.io 2.
func engineDidSendPing()
/// Called when the engine sends a pong to the server. Only called in socket.io >3.
func engineDidSendPong() func engineDidSendPong()
/// Called when the engine has a message that must be parsed. /// Called when the engine has a message that must be parsed.

View File

@ -79,7 +79,15 @@ extension SocketEnginePollable {
postWait.removeAll(keepingCapacity: true) postWait.removeAll(keepingCapacity: true)
} }
let postStr = postWait.lazy.map({ $0.msg }).joined(separator: "\u{1e}") var postStr = ""
if version.rawValue >= 3 {
postStr = postWait.lazy.map({ $0.msg }).joined(separator: "\u{1e}")
} else {
for packet in postWait {
postStr += "\(packet.msg.utf16.count):\(packet.msg)"
}
}
DefaultSocketLogger.Logger.log("Created POST string: \(postStr)", type: "SocketEnginePolling") DefaultSocketLogger.Logger.log("Created POST string: \(postStr)", type: "SocketEnginePolling")
@ -195,10 +203,29 @@ extension SocketEnginePollable {
DefaultSocketLogger.Logger.log("Got poll message: \(str)", type: "SocketEnginePolling") DefaultSocketLogger.Logger.log("Got poll message: \(str)", type: "SocketEnginePolling")
let records = str.components(separatedBy: "\u{1e}") if version.rawValue >= 3 {
let records = str.components(separatedBy: "\u{1e}")
for record in records { for record in records {
parseEngineMessage(record) parseEngineMessage(record)
}
} else {
guard str.count != 1 else {
parseEngineMessage(str)
return
}
var reader = SocketStringReader(message: str)
while reader.hasNext {
if let n = Int(reader.readUntilOccurence(of: ":")) {
parseEngineMessage(reader.read(count: n))
} else {
parseEngineMessage(str)
break
}
}
} }
} }

View File

@ -81,6 +81,9 @@ public protocol SocketEngineSpec: class {
/// The url for WebSockets. /// The url for WebSockets.
var urlWebSocket: URL { get } var urlWebSocket: URL { get }
/// The version of engine.io being used. Default is three.
var version: SocketIOVersion { get }
/// If `true`, then the engine is currently in WebSockets mode. /// If `true`, then the engine is currently in WebSockets mode.
@available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets") @available(*, deprecated, message: "No longer needed, if we're not polling, then we must be doing websockets")
var websocket: Bool { get } var websocket: Bool { get }
@ -142,10 +145,23 @@ public protocol SocketEngineSpec: class {
} }
extension SocketEngineSpec { extension SocketEngineSpec {
var engineIOParam: String {
switch version {
case .two:
return "&EIO=3"
case .three:
return "&EIO=4"
}
}
var urlPollingWithSid: URL { var urlPollingWithSid: URL {
var com = URLComponents(url: urlPolling, resolvingAgainstBaseURL: false)! var com = URLComponents(url: urlPolling, resolvingAgainstBaseURL: false)!
com.percentEncodedQuery = com.percentEncodedQuery! + "&sid=\(sid.urlEncode()!)" com.percentEncodedQuery = com.percentEncodedQuery! + "&sid=\(sid.urlEncode()!)"
if !com.percentEncodedQuery!.contains("EIO") {
com.percentEncodedQuery = com.percentEncodedQuery! + engineIOParam
}
return com.url! return com.url!
} }
@ -153,6 +169,11 @@ extension SocketEngineSpec {
var com = URLComponents(url: urlWebSocket, resolvingAgainstBaseURL: false)! var com = URLComponents(url: urlWebSocket, resolvingAgainstBaseURL: false)!
com.percentEncodedQuery = com.percentEncodedQuery! + (sid == "" ? "" : "&sid=\(sid.urlEncode()!)") com.percentEncodedQuery = com.percentEncodedQuery! + (sid == "" ? "" : "&sid=\(sid.urlEncode()!)")
if !com.percentEncodedQuery!.contains("EIO") {
com.percentEncodedQuery = com.percentEncodedQuery! + engineIOParam
}
return com.url! return com.url!
} }
@ -172,10 +193,12 @@ extension SocketEngineSpec {
} }
func createBinaryDataForSend(using data: Data) -> Either<Data, String> { func createBinaryDataForSend(using data: Data) -> Either<Data, String> {
let prefixB64 = version.rawValue >= 3 ? "b" : "b4"
if polling { if polling {
return .right("b" + data.base64EncodedString(options: Data.Base64EncodingOptions(rawValue: 0))) return .right(prefixB64 + data.base64EncodedString(options: Data.Base64EncodingOptions(rawValue: 0)))
} else { } else {
return .left(data) return .left(version.rawValue >= 3 ? data : Data([0x4]) + data)
} }
} }

View File

@ -119,6 +119,8 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
} }
} }
public private(set) var version = SocketIOVersion.three
/// A list of packets that are waiting for binary data. /// A list of packets that are waiting for binary data.
/// ///
/// The way that socket.io works all data should be sent directly after each packet. /// The way that socket.io works all data should be sent directly after each packet.
@ -214,7 +216,7 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
var payloadStr = "" var payloadStr = ""
if payload != nil, if version.rawValue >= 3 && payload != nil,
let payloadData = try? JSONSerialization.data(withJSONObject: payload!, options: .fragmentsAllowed), let payloadData = try? JSONSerialization.data(withJSONObject: payload!, options: .fragmentsAllowed),
let jsonString = String(data: payloadData, encoding: .utf8) { let jsonString = String(data: payloadData, encoding: .utf8) {
payloadStr = jsonString payloadStr = jsonString
@ -349,12 +351,20 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
status = .connected status = .connected
for (_, socket) in nsps where socket.status == .connecting { if version.rawValue < 3 {
nsps["/"]?.didConnect(toNamespace: "/", payload: nil)
}
for (nsp, socket) in nsps where socket.status == .connecting {
if version.rawValue < 3 && nsp == "/" {
continue
}
connectSocket(socket, withPayload: socket.connectPayload) connectSocket(socket, withPayload: socket.connectPayload)
} }
} }
/// Called when the engine receives a pong message. /// Called when the engine receives a ping message.
open func engineDidReceivePing() { open func engineDidReceivePing() {
handleQueue.async { handleQueue.async {
self._engineDidReceivePing() self._engineDidReceivePing()
@ -366,6 +376,28 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
} }
/// Called when the sends a ping to the server. /// Called when the sends a ping to the server.
open func engineDidSendPing() {
handleQueue.async {
self._engineDidSendPing()
}
}
private func _engineDidSendPing() {
emitAll(clientEvent: .ping, data: [])
}
/// Called when the engine receives a pong message.
open func engineDidReceivePong() {
handleQueue.async {
self._engineDidReceivePong()
}
}
private func _engineDidReceivePong() {
emitAll(clientEvent: .pong, data: [])
}
/// Called when the sends a pong to the server.
open func engineDidSendPong() { open func engineDidSendPong() {
handleQueue.async { handleQueue.async {
self._engineDidSendPong() self._engineDidSendPong()
@ -508,13 +540,13 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
for option in config { for option in config {
switch option { switch option {
case let .forceNew(new): case let .forceNew(new):
self.forceNew = new forceNew = new
case let .handleQueue(queue): case let .handleQueue(queue):
self.handleQueue = queue handleQueue = queue
case let .reconnects(reconnects): case let .reconnects(reconnects):
self.reconnects = reconnects self.reconnects = reconnects
case let .reconnectAttempts(attempts): case let .reconnectAttempts(attempts):
self.reconnectAttempts = attempts reconnectAttempts = attempts
case let .reconnectWait(wait): case let .reconnectWait(wait):
reconnectWait = abs(wait) reconnectWait = abs(wait)
case let .reconnectWaitMax(wait): case let .reconnectWaitMax(wait):
@ -525,6 +557,8 @@ open class SocketManager: NSObject, SocketManagerSpec, SocketParsable, SocketDat
DefaultSocketLogger.Logger.log = log DefaultSocketLogger.Logger.log = log
case let .logger(logger): case let .logger(logger):
DefaultSocketLogger.Logger = logger DefaultSocketLogger.Logger = logger
case let .version(num):
version = num
case _: case _:
continue continue
} }

View File

@ -83,6 +83,9 @@ public protocol SocketManagerSpec : AnyObject, SocketEngineClient {
/// The status of this manager. /// The status of this manager.
var status: SocketIOStatus { get } var status: SocketIOStatus { get }
/// The version of socket.io in use.
var version: SocketIOVersion { get }
// MARK: Methods // MARK: Methods
/// Connects the underlying transport. /// Connects the underlying transport.

View File

@ -10,6 +10,19 @@ import XCTest
@testable import SocketIO @testable import SocketIO
class SocketEngineTest: XCTestCase { class SocketEngineTest: XCTestCase {
func testBasicPollingMessageV3() {
let expect = expectation(description: "Basic polling test v3")
socket.on("blankTest") {data, ack in
expect.fulfill()
}
engine.setConfigs([.version(.two)])
engine.parsePollingMessage("15:42[\"blankTest\"]")
waitForExpectations(timeout: 3, handler: nil)
}
func testBasicPollingMessage() { func testBasicPollingMessage() {
let expect = expectation(description: "Basic polling test") let expect = expectation(description: "Basic polling test")
socket.on("blankTest") {data, ack in socket.on("blankTest") {data, ack in
@ -83,15 +96,15 @@ class SocketEngineTest: XCTestCase {
"created": "2016-05-04T18:31:15+0200" "created": "2016-05-04T18:31:15+0200"
] ]
XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&created=2016-05-04T18%3A31%3A15%2B0200") XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&created=2016-05-04T18%3A31%3A15%2B0200&EIO=4")
XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&created=2016-05-04T18%3A31%3A15%2B0200") XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&created=2016-05-04T18%3A31%3A15%2B0200&EIO=4")
engine.connectParams = [ engine.connectParams = [
"forbidden": "!*'();:@&=+$,/?%#[]\" {}^|" "forbidden": "!*'();:@&=+$,/?%#[]\" {}^|"
] ]
XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C") XCTAssertEqual(engine.urlPolling.query, "transport=polling&b64=1&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C&EIO=4")
XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C") XCTAssertEqual(engine.urlWebSocket.query, "transport=websocket&forbidden=%21%2A%27%28%29%3B%3A%40%26%3D%2B%24%2C%2F%3F%25%23%5B%5D%22%20%7B%7D%5E%7C&EIO=4")
} }
func testBase64Data() { func testBase64Data() {

View File

@ -485,6 +485,7 @@ class TestEngine: SocketEngineSpec {
private(set) var urlWebSocket = URL(string: "http://localhost/")! private(set) var urlWebSocket = URL(string: "http://localhost/")!
private(set) var websocket = false private(set) var websocket = false
private(set) var ws: WebSocket? = nil private(set) var ws: WebSocket? = nil
private(set) var version = SocketIOVersion.three
fileprivate var onConnect: (() -> ())? fileprivate var onConnect: (() -> ())?