diff --git a/Source/WebSocket.swift b/Source/WebSocket.swift index 4af04fa..d8006c3 100644 --- a/Source/WebSocket.swift +++ b/Source/WebSocket.swift @@ -57,7 +57,7 @@ public class WebSocket : NSObject, NSStreamDelegate { //Where the callback is executed. It defaults to the main UI thread queue. public var queue = dispatch_get_main_queue() - var optionalProtocols : Array? + var optionalProtocols : [String]? //Constant Values. let headerWSUpgradeName = "Upgrade" let headerWSUpgradeValue = "websocket" @@ -93,7 +93,7 @@ public class WebSocket : NSObject, NSStreamDelegate { public var onText: ((String) -> Void)? public var onData: ((NSData) -> Void)? public var onPong: ((Void) -> Void)? - public var headers = Dictionary() + public var headers = [String: String]() public var voipEnabled = false public var selfSignedSSL = false private var security: SSLSecurity? @@ -108,36 +108,31 @@ public class WebSocket : NSObject, NSStreamDelegate { private var connected = false private var isCreated = false private var writeQueue = NSOperationQueue() - private var readStack = Array() - private var inputQueue = Array() + private var readStack = [WSResponse]() + private var inputQueue = [NSData]() private var fragBuffer: NSData? private var certValidated = false private var didDisconnect = false - //init the websocket with a url - public init(url: NSURL) { + //used for setting protocols. + public init(url: NSURL, protocols: [String]? = nil) { self.url = url writeQueue.maxConcurrentOperationCount = 1 - } - //used for setting protocols. - public convenience init(url: NSURL, protocols: Array) { - self.init(url: url) optionalProtocols = protocols } ///Connect to the websocket server on a background thread public func connect() { - if isCreated { - return - } - dispatch_async(queue,{ [weak self] in + guard !isCreated else { return } + + dispatch_async(queue) { [weak self] in self?.didDisconnect = false - }) - dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0), { [weak self] in + } + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0)) { [weak self] in self?.isCreated = true self?.createHTTPRequest() self?.isCreated = false - }) + } } /** @@ -152,9 +147,9 @@ public class WebSocket : NSObject, NSStreamDelegate { public func disconnect(forceTimeout forceTimeout: NSTimeInterval? = nil) { switch forceTimeout { case .Some(let seconds) where seconds > 0: - dispatch_after(dispatch_time(DISPATCH_TIME_NOW, Int64(seconds * Double(NSEC_PER_SEC))), queue, { [unowned self] in + dispatch_after(dispatch_time(DISPATCH_TIME_NOW, Int64(seconds * Double(NSEC_PER_SEC))), queue) { [unowned self] in self.disconnectStream(nil) - }) + } fallthrough case .None: writeError(CloseCode.Normal.rawValue) @@ -190,7 +185,7 @@ public class WebSocket : NSObject, NSStreamDelegate { var port = url.port if port == nil { - if url.scheme == "wss" || url.scheme == "https" { + if ["wss", "https"].contains(url.scheme) { port = 443 } else { port = 80 @@ -214,18 +209,14 @@ public class WebSocket : NSObject, NSStreamDelegate { } } //Add a header to the CFHTTPMessage by using the NSString bridges to CFString - private func addHeader(urlRequest: CFHTTPMessage,key: String, val: String) { - let nsKey: NSString = key - let nsVal: NSString = val - CFHTTPMessageSetHeaderFieldValue(urlRequest, - nsKey, - nsVal) + private func addHeader(urlRequest: CFHTTPMessage, key: NSString, val: NSString) { + CFHTTPMessageSetHeaderFieldValue(urlRequest, key, val) } //generate a websocket key as needed in rfc private func generateWebSocketKey() -> String { var key = "" let seed = 16 - for (var i = 0; i < seed; i++) { + for _ in 0.. = [kCFStreamSSLValidatesCertificateChain: NSNumber(bool:false), kCFStreamSSLPeerName: kCFNull] + let settings: [NSObject: NSObject] = [kCFStreamSSLValidatesCertificateChain: NSNumber(bool:false), kCFStreamSSLPeerName: kCFNull] inStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String) outStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String) } @@ -267,12 +258,12 @@ public class WebSocket : NSObject, NSStreamDelegate { sslContextOut = CFWriteStreamCopyProperty(outputStream, kCFStreamPropertySSLContext) as! SSLContextRef? { let resIn = SSLSetEnabledCiphers(sslContextIn, cipherSuites, cipherSuites.count) let resOut = SSLSetEnabledCiphers(sslContextOut, cipherSuites, cipherSuites.count) - if (resIn != errSecSuccess) { + if resIn != errSecSuccess { let error = self.errorWithDetail("Error setting ingoing cypher suites", code: UInt16(resIn)) disconnectStream(error) return } - if (resOut != errSecSuccess) { + if resOut != errSecSuccess { let error = self.errorWithDetail("Error setting outgoing cypher suites", code: UInt16(resOut)) disconnectStream(error) return @@ -293,7 +284,7 @@ public class WebSocket : NSObject, NSStreamDelegate { //delegate for the stream methods. Processes incoming bytes public func stream(aStream: NSStream, handleEvent eventCode: NSStreamEvent) { - if let sec = security where !certValidated && (eventCode == .HasBytesAvailable || eventCode == .HasSpaceAvailable) { + if let sec = security where !certValidated && [.HasBytesAvailable, .HasSpaceAvailable].contains(eventCode) { let possibleTrust: AnyObject? = aStream.propertyForKey(kCFStreamPropertySSLPeerTrust as String) if let trust: AnyObject = possibleTrust { let domain: AnyObject? = aStream.propertyForKey(kCFStreamSSLPeerName as String) @@ -307,7 +298,7 @@ public class WebSocket : NSObject, NSStreamDelegate { } } if eventCode == .HasBytesAvailable { - if(aStream == inputStream) { + if aStream == inputStream { processInputStream() } } else if eventCode == .ErrorOccurred { @@ -339,50 +330,51 @@ public class WebSocket : NSObject, NSStreamDelegate { let buf = NSMutableData(capacity: BUFFER_MAX) let buffer = UnsafeMutablePointer(buf!.bytes) let length = inputStream!.read(buffer, maxLength: BUFFER_MAX) - if length > 0 { + + guard length > 0 else { return } + + if !connected { + connected = processHTTP(buffer, bufferLen: length) if !connected { - connected = processHTTP(buffer, bufferLen: length) - if !connected { - let response = CFHTTPMessageCreateEmpty(kCFAllocatorDefault, false).takeRetainedValue() - CFHTTPMessageAppendBytes(response, buffer, length) - let code = CFHTTPMessageGetResponseStatusCode(response) - doDisconnect(errorWithDetail("Invalid HTTP upgrade", code: UInt16(code))) - } - } else { - var process = false - if inputQueue.count == 0 { - process = true - } - inputQueue.append(NSData(bytes: buffer, length: length)) - if process { - dequeueInput() - } + let response = CFHTTPMessageCreateEmpty(kCFAllocatorDefault, false).takeRetainedValue() + CFHTTPMessageAppendBytes(response, buffer, length) + let code = CFHTTPMessageGetResponseStatusCode(response) + doDisconnect(errorWithDetail("Invalid HTTP upgrade", code: UInt16(code))) + } + } else { + var process = false + if inputQueue.count == 0 { + process = true + } + inputQueue.append(NSData(bytes: buffer, length: length)) + if process { + dequeueInput() } } } ///dequeue the incoming input so it is processed in order private func dequeueInput() { - if inputQueue.count > 0 { - let data = inputQueue[0] - var work = data - if fragBuffer != nil { - let combine = NSMutableData(data: fragBuffer!) - combine.appendData(data) - work = combine - fragBuffer = nil - } - let buffer = UnsafePointer(work.bytes) - processRawMessage(buffer, bufferLen: work.length) - inputQueue = inputQueue.filter{$0 != data} - dequeueInput() + guard !inputQueue.isEmpty else { return } + + let data = inputQueue[0] + var work = data + if let fragBuffer = fragBuffer { + let combine = NSMutableData(data: fragBuffer) + combine.appendData(data) + work = combine + self.fragBuffer = nil } + let buffer = UnsafePointer(work.bytes) + processRawMessage(buffer, bufferLen: work.length) + inputQueue = inputQueue.filter{$0 != data} + dequeueInput() } ///Finds the HTTP Packet in the TCP stream, by looking for the CRLF. private func processHTTP(buffer: UnsafePointer, bufferLen: Int) -> Bool { let CRLFBytes = [UInt8(ascii: "\r"), UInt8(ascii: "\n"), UInt8(ascii: "\r"), UInt8(ascii: "\n")] var k = 0 var totalSize = 0 - for var i = 0; i < bufferLen; i++ { + for i in 0.. 0 { if validateResponse(buffer, bufferLen: totalSize) { - dispatch_async(queue,{ [weak self] in + dispatch_async(queue) { [weak self] in guard let s = self else { return } - if let connectBlock = s.onConnect { - connectBlock() - } + s.onConnect?() s.delegate?.websocketDidConnect(s) - }) + } totalSize += 1 //skip the last \n let restSize = bufferLen - totalSize if restSize > 0 { @@ -438,17 +428,16 @@ public class WebSocket : NSObject, NSStreamDelegate { fragBuffer = NSData(bytes: buffer, length: bufferLen) return } - if response != nil && response!.bytesLeft > 0 { - let resp = response! - var len = resp.bytesLeft - var extra = bufferLen - resp.bytesLeft - if resp.bytesLeft > bufferLen { + if let response = response where response.bytesLeft > 0 { + var len = response.bytesLeft + var extra = bufferLen - response.bytesLeft + if response.bytesLeft > bufferLen { len = bufferLen extra = 0 } - resp.bytesLeft -= len - resp.buffer?.appendData(NSData(bytes: buffer, length: len)) - processResponse(resp) + response.bytesLeft -= len + response.buffer?.appendData(NSData(bytes: buffer, length: len)) + processResponse(response) let offset = bufferLen - extra if extra > 0 { processExtra((buffer+offset), bufferLen: extra) @@ -456,19 +445,19 @@ public class WebSocket : NSObject, NSStreamDelegate { return } else { let isFin = (FinMask & buffer[0]) - let receivedOpcode = (OpCodeMask & buffer[0]) + let receivedOpcode = OpCode(rawValue: (OpCodeMask & buffer[0])) let isMasked = (MaskMask & buffer[1]) let payloadLen = (PayloadLenMask & buffer[1]) var offset = 2 - if((isMasked > 0 || (RSVMask & buffer[0]) > 0) && receivedOpcode != OpCode.Pong.rawValue) { + if (isMasked > 0 || (RSVMask & buffer[0]) > 0) && receivedOpcode != .Pong { let errCode = CloseCode.ProtocolError.rawValue doDisconnect(errorWithDetail("masked and rsv data is not currently supported", code: errCode)) writeError(errCode) return } - let isControlFrame = (receivedOpcode == OpCode.ConnectionClose.rawValue || receivedOpcode == OpCode.Ping.rawValue) - if !isControlFrame && (receivedOpcode != OpCode.BinaryFrame.rawValue && receivedOpcode != OpCode.ContinueFrame.rawValue && - receivedOpcode != OpCode.TextFrame.rawValue && receivedOpcode != OpCode.Pong.rawValue) { + let isControlFrame = (receivedOpcode == .ConnectionClose || receivedOpcode == .Ping) + if !isControlFrame && (receivedOpcode != .BinaryFrame && receivedOpcode != .ContinueFrame && + receivedOpcode != .TextFrame && receivedOpcode != .Pong) { let errCode = CloseCode.ProtocolError.rawValue doDisconnect(errorWithDetail("unknown opcode: \(receivedOpcode)", code: errCode)) writeError(errCode) @@ -480,7 +469,7 @@ public class WebSocket : NSObject, NSStreamDelegate { writeError(errCode) return } - if receivedOpcode == OpCode.ConnectionClose.rawValue { + if receivedOpcode == .ConnectionClose { var code = CloseCode.Normal.rawValue if payloadLen == 1 { code = CloseCode.ProtocolError.rawValue @@ -535,14 +524,12 @@ public class WebSocket : NSObject, NSStreamDelegate { } else { data = NSData(bytes: UnsafePointer((buffer+offset)), length: Int(len)) } - if receivedOpcode == OpCode.Pong.rawValue { - dispatch_async(queue,{ [weak self] in + if receivedOpcode == .Pong { + dispatch_async(queue) { [weak self] in guard let s = self else { return } - if let pongBlock = s.onPong { - pongBlock() - } + s.onPong?() s.pongDelegate?.websocketDidReceivePong(s) - }) + } let step = Int(offset+numericCast(len)) let extra = bufferLen-step @@ -555,15 +542,15 @@ public class WebSocket : NSObject, NSStreamDelegate { if isControlFrame { response = nil //don't append pings } - if isFin == 0 && receivedOpcode == OpCode.ContinueFrame.rawValue && response == nil { + if isFin == 0 && receivedOpcode == .ContinueFrame && response == nil { let errCode = CloseCode.ProtocolError.rawValue doDisconnect(errorWithDetail("continue frame before a binary or text frame", code: errCode)) writeError(errCode) return } var isNew = false - if(response == nil) { - if receivedOpcode == OpCode.ContinueFrame.rawValue { + if response == nil { + if receivedOpcode == .ContinueFrame { let errCode = CloseCode.ProtocolError.rawValue doDisconnect(errorWithDetail("first frame can't be a continue frame", code: errCode)) @@ -572,11 +559,11 @@ public class WebSocket : NSObject, NSStreamDelegate { } isNew = true response = WSResponse() - response!.code = OpCode(rawValue: receivedOpcode)! + response!.code = receivedOpcode! response!.bytesLeft = Int(dataLength) response!.buffer = NSMutableData(data: data) } else { - if receivedOpcode == OpCode.ContinueFrame.rawValue { + if receivedOpcode == .ContinueFrame { response!.bytesLeft = Int(dataLength) } else { let errCode = CloseCode.ProtocolError.rawValue @@ -587,19 +574,19 @@ public class WebSocket : NSObject, NSStreamDelegate { } response!.buffer!.appendData(data) } - if response != nil { - response!.bytesLeft -= Int(len) - response!.frameCount++ - response!.isFin = isFin > 0 ? true : false - if(isNew) { - readStack.append(response!) + if let response = response { + response.bytesLeft -= Int(len) + response.frameCount++ + response.isFin = isFin > 0 ? true : false + if isNew { + readStack.append(response) } - processResponse(response!) + processResponse(response) } let step = Int(offset+numericCast(len)) let extra = bufferLen-step - if(extra > 0) { + if extra > 0 { processExtra((buffer+step), bufferLen: extra) } } @@ -628,22 +615,18 @@ public class WebSocket : NSObject, NSStreamDelegate { return false } - dispatch_async(queue,{ [weak self] in + dispatch_async(queue) { [weak self] in guard let s = self else { return } - if let textBlock = s.onText { - textBlock(str! as String) - } + s.onText?(str! as String) s.delegate?.websocketDidReceiveMessage(s, text: str! as String) - }) + } } else if response.code == .BinaryFrame { let data = response.buffer! //local copy so it is perverse for writing - dispatch_async(queue,{ [weak self] in + dispatch_async(queue) { [weak self] in guard let s = self else { return } - if let dataBlock = s.onData { - dataBlock(data) - } + s.onData?(data) s.delegate?.websocketDidReceiveData(s, data: data) - }) + } } readStack.removeLast() return true @@ -653,7 +636,7 @@ public class WebSocket : NSObject, NSStreamDelegate { ///Create an error private func errorWithDetail(detail: String, code: UInt16) -> NSError { - var details = Dictionary() + var details = [String: String]() details[NSLocalizedDescriptionKey] = detail return NSError(domain: WebSocket.ErrorDomain, code: Int(code), userInfo: details) } @@ -667,9 +650,8 @@ public class WebSocket : NSObject, NSStreamDelegate { } ///used to write things to the stream private func dequeueWrite(data: NSData, code: OpCode) { - if !isConnected { - return - } + guard isConnected else { return } + writeQueue.addOperationWithBlock { [weak self] in //stream isn't ready, let's wait guard let s = self else { return } @@ -697,7 +679,7 @@ public class WebSocket : NSObject, NSStreamDelegate { SecRandomCopyBytes(kSecRandomDefault, Int(sizeof(UInt32)), maskKey) offset += sizeof(UInt32) - for (var i = 0; i < dataLength; i++) { + for i in 0..() - for path in paths { - if let d = NSData(contentsOfFile: path as String) { - collect.append(SSLCert(data: d)) + + let certs = paths.reduce([SSLCert]()) { (var certs: [SSLCert], path: String) -> [SSLCert] in + if let data = NSData(contentsOfFile: path) { + certs.append(SSLCert(data: data)) } + return certs } - self.init(certs:collect, usePublicKeys: usePublicKeys) + + self.init(certs: certs, usePublicKeys: usePublicKeys) } /** - Designated init - - - parameter keys: is the certificates or public keys to use - - parameter usePublicKeys: is to specific if the publicKeys or certificates should be used for SSL pinning validation - - - returns: a representation security object to be used with - */ + Designated init + + - parameter keys: is the certificates or public keys to use + - parameter usePublicKeys: is to specific if the publicKeys or certificates should be used for SSL pinning validation + + - returns: a representation security object to be used with + */ init(certs: [SSLCert], usePublicKeys: Bool) { self.usePublicKeys = usePublicKeys if self.usePublicKeys { - dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0), { - var collect = Array() - for cert in certs { - if let data = cert.certData where cert.key == nil { + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0)) { + let pubKeys = certs.reduce([SecKeyRef]()) { (var pubKeys: [SecKeyRef], cert: SSLCert) -> [SecKeyRef] in + if let data = cert.certData where cert.key == nil { cert.key = self.extractPublicKey(data) } - if let k = cert.key { - collect.append(k) + if let key = cert.key { + pubKeys.append(key) } + return pubKeys } - self.pubKeys = collect + + self.pubKeys = pubKeys self.isReady = true - }) - } else { - var collect = Array() - for cert in certs { - if let d = cert.certData { - collect.append(d) - } } - self.certificates = collect + } else { + let certificates = certs.reduce([NSData]()) { (var certificates: [NSData], cert: SSLCert) -> [NSData] in + if let data = cert.certData { + certificates.append(data) + } + return certificates + } + self.certificates = certificates self.isReady = true } } /** - Valid the trust and domain name. - - - parameter trust: is the serverTrust to validate - - parameter domain: is the CN domain to validate - - - returns: if the key was successfully validated - */ + Valid the trust and domain name. + + - parameter trust: is the serverTrust to validate + - parameter domain: is the CN domain to validate + + - returns: if the key was successfully validated + */ func isValid(trust: SecTrustRef, domain: String?) -> Bool { var tries = 0 @@ -863,23 +845,18 @@ private class SSLSecurity { SecTrustSetPolicies(trust,policy) if self.usePublicKeys { if let keys = self.pubKeys { - var trustedCount = 0 let serverPubKeys = publicKeyChainForTrust(trust) for serverKey in serverPubKeys as [AnyObject] { for key in keys as [AnyObject] { if serverKey.isEqual(key) { - trustedCount++ - break + return true } } } - if trustedCount == serverPubKeys.count { - return true - } } } else if let certs = self.certificates { let serverCerts = certificateChainForTrust(trust) - var collect = Array() + var collect = [SecCertificate]() for cert in certs { collect.append(SecCertificateCreateWithData(nil,cert)!) } @@ -906,71 +883,72 @@ private class SSLSecurity { } /** - Get the public key from a certificate data - - - parameter data: is the certificate to pull the public key from - - - returns: a public key - */ + Get the public key from a certificate data + + - parameter data: is the certificate to pull the public key from + + - returns: a public key + */ func extractPublicKey(data: NSData) -> SecKeyRef? { - let possibleCert = SecCertificateCreateWithData(nil,data) - if let cert = possibleCert { - return extractPublicKeyFromCert(cert, policy: SecPolicyCreateBasicX509()) - } - return nil + guard let cert = SecCertificateCreateWithData(nil, data) else { return nil } + + return extractPublicKeyFromCert(cert, policy: SecPolicyCreateBasicX509()) } /** - Get the public key from a certificate - - - parameter data: is the certificate to pull the public key from - - - returns: a public key - */ + Get the public key from a certificate + + - parameter data: is the certificate to pull the public key from + + - returns: a public key + */ func extractPublicKeyFromCert(cert: SecCertificate, policy: SecPolicy) -> SecKeyRef? { var possibleTrust: SecTrust? SecTrustCreateWithCertificates(cert, policy, &possibleTrust) - if let trust = possibleTrust { - var result: SecTrustResultType = 0 - SecTrustEvaluate(trust, &result) - return SecTrustCopyPublicKey(trust) - } - return nil + + guard let trust = possibleTrust else { return nil } + + var result: SecTrustResultType = 0 + SecTrustEvaluate(trust, &result) + return SecTrustCopyPublicKey(trust) } /** - Get the certificate chain for the trust - - - parameter trust: is the trust to lookup the certificate chain for - - - returns: the certificate chain for the trust - */ - func certificateChainForTrust(trust: SecTrustRef) -> Array { - var collect = Array() - for var i = 0; i < SecTrustGetCertificateCount(trust); i++ { - let cert = SecTrustGetCertificateAtIndex(trust,i) - collect.append(SecCertificateCopyData(cert!)) + Get the certificate chain for the trust + + - parameter trust: is the trust to lookup the certificate chain for + + - returns: the certificate chain for the trust + */ + func certificateChainForTrust(trust: SecTrustRef) -> [NSData] { + let certificates = (0.. [NSData] in + let cert = SecTrustGetCertificateAtIndex(trust, index) + certificates.append(SecCertificateCopyData(cert!)) + return certificates } - return collect + + return certificates } /** - Get the public key chain for the trust - - - parameter trust: is the trust to lookup the certificate chain and extract the public keys - - - returns: the public keys from the certifcate chain for the trust - */ - func publicKeyChainForTrust(trust: SecTrustRef) -> Array { - var collect = Array() + Get the public key chain for the trust + + - parameter trust: is the trust to lookup the certificate chain and extract the public keys + + - returns: the public keys from the certifcate chain for the trust + */ + func publicKeyChainForTrust(trust: SecTrustRef) -> [SecKeyRef] { let policy = SecPolicyCreateBasicX509() - for var i = 0; i < SecTrustGetCertificateCount(trust); i++ { - let cert = SecTrustGetCertificateAtIndex(trust,i) + let keys = (0.. [SecKeyRef] in + let cert = SecTrustGetCertificateAtIndex(trust, index) if let key = extractPublicKeyFromCert(cert!, policy: policy) { - collect.append(key) + keys.append(key) } + + return keys } - return collect + + return keys }