bump websocket

This commit is contained in:
Erik 2015-12-07 14:46:27 -05:00
parent c642017f4b
commit 2ccaa4351a

View File

@ -57,7 +57,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
//Where the callback is executed. It defaults to the main UI thread queue. //Where the callback is executed. It defaults to the main UI thread queue.
public var queue = dispatch_get_main_queue() public var queue = dispatch_get_main_queue()
var optionalProtocols : Array<String>? var optionalProtocols : [String]?
//Constant Values. //Constant Values.
let headerWSUpgradeName = "Upgrade" let headerWSUpgradeName = "Upgrade"
let headerWSUpgradeValue = "websocket" let headerWSUpgradeValue = "websocket"
@ -93,7 +93,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
public var onText: ((String) -> Void)? public var onText: ((String) -> Void)?
public var onData: ((NSData) -> Void)? public var onData: ((NSData) -> Void)?
public var onPong: ((Void) -> Void)? public var onPong: ((Void) -> Void)?
public var headers = Dictionary<String,String>() public var headers = [String: String]()
public var voipEnabled = false public var voipEnabled = false
public var selfSignedSSL = false public var selfSignedSSL = false
private var security: SSLSecurity? private var security: SSLSecurity?
@ -108,36 +108,31 @@ public class WebSocket : NSObject, NSStreamDelegate {
private var connected = false private var connected = false
private var isCreated = false private var isCreated = false
private var writeQueue = NSOperationQueue() private var writeQueue = NSOperationQueue()
private var readStack = Array<WSResponse>() private var readStack = [WSResponse]()
private var inputQueue = Array<NSData>() private var inputQueue = [NSData]()
private var fragBuffer: NSData? private var fragBuffer: NSData?
private var certValidated = false private var certValidated = false
private var didDisconnect = false private var didDisconnect = false
//init the websocket with a url //used for setting protocols.
public init(url: NSURL) { public init(url: NSURL, protocols: [String]? = nil) {
self.url = url self.url = url
writeQueue.maxConcurrentOperationCount = 1 writeQueue.maxConcurrentOperationCount = 1
}
//used for setting protocols.
public convenience init(url: NSURL, protocols: Array<String>) {
self.init(url: url)
optionalProtocols = protocols optionalProtocols = protocols
} }
///Connect to the websocket server on a background thread ///Connect to the websocket server on a background thread
public func connect() { public func connect() {
if isCreated { guard !isCreated else { return }
return
} dispatch_async(queue) { [weak self] in
dispatch_async(queue,{ [weak self] in
self?.didDisconnect = false self?.didDisconnect = false
}) }
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0), { [weak self] in dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0)) { [weak self] in
self?.isCreated = true self?.isCreated = true
self?.createHTTPRequest() self?.createHTTPRequest()
self?.isCreated = false self?.isCreated = false
}) }
} }
/** /**
@ -152,9 +147,9 @@ public class WebSocket : NSObject, NSStreamDelegate {
public func disconnect(forceTimeout forceTimeout: NSTimeInterval? = nil) { public func disconnect(forceTimeout forceTimeout: NSTimeInterval? = nil) {
switch forceTimeout { switch forceTimeout {
case .Some(let seconds) where seconds > 0: case .Some(let seconds) where seconds > 0:
dispatch_after(dispatch_time(DISPATCH_TIME_NOW, Int64(seconds * Double(NSEC_PER_SEC))), queue, { [unowned self] in dispatch_after(dispatch_time(DISPATCH_TIME_NOW, Int64(seconds * Double(NSEC_PER_SEC))), queue) { [unowned self] in
self.disconnectStream(nil) self.disconnectStream(nil)
}) }
fallthrough fallthrough
case .None: case .None:
writeError(CloseCode.Normal.rawValue) writeError(CloseCode.Normal.rawValue)
@ -190,7 +185,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
var port = url.port var port = url.port
if port == nil { if port == nil {
if url.scheme == "wss" || url.scheme == "https" { if ["wss", "https"].contains(url.scheme) {
port = 443 port = 443
} else { } else {
port = 80 port = 80
@ -214,18 +209,14 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
} }
//Add a header to the CFHTTPMessage by using the NSString bridges to CFString //Add a header to the CFHTTPMessage by using the NSString bridges to CFString
private func addHeader(urlRequest: CFHTTPMessage,key: String, val: String) { private func addHeader(urlRequest: CFHTTPMessage, key: NSString, val: NSString) {
let nsKey: NSString = key CFHTTPMessageSetHeaderFieldValue(urlRequest, key, val)
let nsVal: NSString = val
CFHTTPMessageSetHeaderFieldValue(urlRequest,
nsKey,
nsVal)
} }
//generate a websocket key as needed in rfc //generate a websocket key as needed in rfc
private func generateWebSocketKey() -> String { private func generateWebSocketKey() -> String {
var key = "" var key = ""
let seed = 16 let seed = 16
for (var i = 0; i < seed; i++) { for _ in 0..<seed {
let uni = UnicodeScalar(UInt32(97 + arc4random_uniform(25))) let uni = UnicodeScalar(UInt32(97 + arc4random_uniform(25)))
key += "\(Character(uni))" key += "\(Character(uni))"
} }
@ -247,7 +238,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
guard let inStream = inputStream, let outStream = outputStream else { return } guard let inStream = inputStream, let outStream = outputStream else { return }
inStream.delegate = self inStream.delegate = self
outStream.delegate = self outStream.delegate = self
if url.scheme == "wss" || url.scheme == "https" { if ["wss", "https"].contains(url.scheme) {
inStream.setProperty(NSStreamSocketSecurityLevelNegotiatedSSL, forKey: NSStreamSocketSecurityLevelKey) inStream.setProperty(NSStreamSocketSecurityLevelNegotiatedSSL, forKey: NSStreamSocketSecurityLevelKey)
outStream.setProperty(NSStreamSocketSecurityLevelNegotiatedSSL, forKey: NSStreamSocketSecurityLevelKey) outStream.setProperty(NSStreamSocketSecurityLevelNegotiatedSSL, forKey: NSStreamSocketSecurityLevelKey)
} else { } else {
@ -258,7 +249,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
outStream.setProperty(NSStreamNetworkServiceTypeVoIP, forKey: NSStreamNetworkServiceType) outStream.setProperty(NSStreamNetworkServiceTypeVoIP, forKey: NSStreamNetworkServiceType)
} }
if selfSignedSSL { if selfSignedSSL {
let settings: Dictionary<NSObject, NSObject> = [kCFStreamSSLValidatesCertificateChain: NSNumber(bool:false), kCFStreamSSLPeerName: kCFNull] let settings: [NSObject: NSObject] = [kCFStreamSSLValidatesCertificateChain: NSNumber(bool:false), kCFStreamSSLPeerName: kCFNull]
inStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String) inStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String)
outStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String) outStream.setProperty(settings, forKey: kCFStreamPropertySSLSettings as String)
} }
@ -267,12 +258,12 @@ public class WebSocket : NSObject, NSStreamDelegate {
sslContextOut = CFWriteStreamCopyProperty(outputStream, kCFStreamPropertySSLContext) as! SSLContextRef? { sslContextOut = CFWriteStreamCopyProperty(outputStream, kCFStreamPropertySSLContext) as! SSLContextRef? {
let resIn = SSLSetEnabledCiphers(sslContextIn, cipherSuites, cipherSuites.count) let resIn = SSLSetEnabledCiphers(sslContextIn, cipherSuites, cipherSuites.count)
let resOut = SSLSetEnabledCiphers(sslContextOut, cipherSuites, cipherSuites.count) let resOut = SSLSetEnabledCiphers(sslContextOut, cipherSuites, cipherSuites.count)
if (resIn != errSecSuccess) { if resIn != errSecSuccess {
let error = self.errorWithDetail("Error setting ingoing cypher suites", code: UInt16(resIn)) let error = self.errorWithDetail("Error setting ingoing cypher suites", code: UInt16(resIn))
disconnectStream(error) disconnectStream(error)
return return
} }
if (resOut != errSecSuccess) { if resOut != errSecSuccess {
let error = self.errorWithDetail("Error setting outgoing cypher suites", code: UInt16(resOut)) let error = self.errorWithDetail("Error setting outgoing cypher suites", code: UInt16(resOut))
disconnectStream(error) disconnectStream(error)
return return
@ -293,7 +284,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
//delegate for the stream methods. Processes incoming bytes //delegate for the stream methods. Processes incoming bytes
public func stream(aStream: NSStream, handleEvent eventCode: NSStreamEvent) { public func stream(aStream: NSStream, handleEvent eventCode: NSStreamEvent) {
if let sec = security where !certValidated && (eventCode == .HasBytesAvailable || eventCode == .HasSpaceAvailable) { if let sec = security where !certValidated && [.HasBytesAvailable, .HasSpaceAvailable].contains(eventCode) {
let possibleTrust: AnyObject? = aStream.propertyForKey(kCFStreamPropertySSLPeerTrust as String) let possibleTrust: AnyObject? = aStream.propertyForKey(kCFStreamPropertySSLPeerTrust as String)
if let trust: AnyObject = possibleTrust { if let trust: AnyObject = possibleTrust {
let domain: AnyObject? = aStream.propertyForKey(kCFStreamSSLPeerName as String) let domain: AnyObject? = aStream.propertyForKey(kCFStreamSSLPeerName as String)
@ -307,7 +298,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
} }
if eventCode == .HasBytesAvailable { if eventCode == .HasBytesAvailable {
if(aStream == inputStream) { if aStream == inputStream {
processInputStream() processInputStream()
} }
} else if eventCode == .ErrorOccurred { } else if eventCode == .ErrorOccurred {
@ -339,7 +330,9 @@ public class WebSocket : NSObject, NSStreamDelegate {
let buf = NSMutableData(capacity: BUFFER_MAX) let buf = NSMutableData(capacity: BUFFER_MAX)
let buffer = UnsafeMutablePointer<UInt8>(buf!.bytes) let buffer = UnsafeMutablePointer<UInt8>(buf!.bytes)
let length = inputStream!.read(buffer, maxLength: BUFFER_MAX) let length = inputStream!.read(buffer, maxLength: BUFFER_MAX)
if length > 0 {
guard length > 0 else { return }
if !connected { if !connected {
connected = processHTTP(buffer, bufferLen: length) connected = processHTTP(buffer, bufferLen: length)
if !connected { if !connected {
@ -359,30 +352,29 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
} }
} }
}
///dequeue the incoming input so it is processed in order ///dequeue the incoming input so it is processed in order
private func dequeueInput() { private func dequeueInput() {
if inputQueue.count > 0 { guard !inputQueue.isEmpty else { return }
let data = inputQueue[0] let data = inputQueue[0]
var work = data var work = data
if fragBuffer != nil { if let fragBuffer = fragBuffer {
let combine = NSMutableData(data: fragBuffer!) let combine = NSMutableData(data: fragBuffer)
combine.appendData(data) combine.appendData(data)
work = combine work = combine
fragBuffer = nil self.fragBuffer = nil
} }
let buffer = UnsafePointer<UInt8>(work.bytes) let buffer = UnsafePointer<UInt8>(work.bytes)
processRawMessage(buffer, bufferLen: work.length) processRawMessage(buffer, bufferLen: work.length)
inputQueue = inputQueue.filter{$0 != data} inputQueue = inputQueue.filter{$0 != data}
dequeueInput() dequeueInput()
} }
}
///Finds the HTTP Packet in the TCP stream, by looking for the CRLF. ///Finds the HTTP Packet in the TCP stream, by looking for the CRLF.
private func processHTTP(buffer: UnsafePointer<UInt8>, bufferLen: Int) -> Bool { private func processHTTP(buffer: UnsafePointer<UInt8>, bufferLen: Int) -> Bool {
let CRLFBytes = [UInt8(ascii: "\r"), UInt8(ascii: "\n"), UInt8(ascii: "\r"), UInt8(ascii: "\n")] let CRLFBytes = [UInt8(ascii: "\r"), UInt8(ascii: "\n"), UInt8(ascii: "\r"), UInt8(ascii: "\n")]
var k = 0 var k = 0
var totalSize = 0 var totalSize = 0
for var i = 0; i < bufferLen; i++ { for i in 0..<bufferLen {
if buffer[i] == CRLFBytes[k] { if buffer[i] == CRLFBytes[k] {
k++ k++
if k == 3 { if k == 3 {
@ -395,13 +387,11 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
if totalSize > 0 { if totalSize > 0 {
if validateResponse(buffer, bufferLen: totalSize) { if validateResponse(buffer, bufferLen: totalSize) {
dispatch_async(queue,{ [weak self] in dispatch_async(queue) { [weak self] in
guard let s = self else { return } guard let s = self else { return }
if let connectBlock = s.onConnect { s.onConnect?()
connectBlock()
}
s.delegate?.websocketDidConnect(s) s.delegate?.websocketDidConnect(s)
}) }
totalSize += 1 //skip the last \n totalSize += 1 //skip the last \n
let restSize = bufferLen - totalSize let restSize = bufferLen - totalSize
if restSize > 0 { if restSize > 0 {
@ -438,17 +428,16 @@ public class WebSocket : NSObject, NSStreamDelegate {
fragBuffer = NSData(bytes: buffer, length: bufferLen) fragBuffer = NSData(bytes: buffer, length: bufferLen)
return return
} }
if response != nil && response!.bytesLeft > 0 { if let response = response where response.bytesLeft > 0 {
let resp = response! var len = response.bytesLeft
var len = resp.bytesLeft var extra = bufferLen - response.bytesLeft
var extra = bufferLen - resp.bytesLeft if response.bytesLeft > bufferLen {
if resp.bytesLeft > bufferLen {
len = bufferLen len = bufferLen
extra = 0 extra = 0
} }
resp.bytesLeft -= len response.bytesLeft -= len
resp.buffer?.appendData(NSData(bytes: buffer, length: len)) response.buffer?.appendData(NSData(bytes: buffer, length: len))
processResponse(resp) processResponse(response)
let offset = bufferLen - extra let offset = bufferLen - extra
if extra > 0 { if extra > 0 {
processExtra((buffer+offset), bufferLen: extra) processExtra((buffer+offset), bufferLen: extra)
@ -456,19 +445,19 @@ public class WebSocket : NSObject, NSStreamDelegate {
return return
} else { } else {
let isFin = (FinMask & buffer[0]) let isFin = (FinMask & buffer[0])
let receivedOpcode = (OpCodeMask & buffer[0]) let receivedOpcode = OpCode(rawValue: (OpCodeMask & buffer[0]))
let isMasked = (MaskMask & buffer[1]) let isMasked = (MaskMask & buffer[1])
let payloadLen = (PayloadLenMask & buffer[1]) let payloadLen = (PayloadLenMask & buffer[1])
var offset = 2 var offset = 2
if((isMasked > 0 || (RSVMask & buffer[0]) > 0) && receivedOpcode != OpCode.Pong.rawValue) { if (isMasked > 0 || (RSVMask & buffer[0]) > 0) && receivedOpcode != .Pong {
let errCode = CloseCode.ProtocolError.rawValue let errCode = CloseCode.ProtocolError.rawValue
doDisconnect(errorWithDetail("masked and rsv data is not currently supported", code: errCode)) doDisconnect(errorWithDetail("masked and rsv data is not currently supported", code: errCode))
writeError(errCode) writeError(errCode)
return return
} }
let isControlFrame = (receivedOpcode == OpCode.ConnectionClose.rawValue || receivedOpcode == OpCode.Ping.rawValue) let isControlFrame = (receivedOpcode == .ConnectionClose || receivedOpcode == .Ping)
if !isControlFrame && (receivedOpcode != OpCode.BinaryFrame.rawValue && receivedOpcode != OpCode.ContinueFrame.rawValue && if !isControlFrame && (receivedOpcode != .BinaryFrame && receivedOpcode != .ContinueFrame &&
receivedOpcode != OpCode.TextFrame.rawValue && receivedOpcode != OpCode.Pong.rawValue) { receivedOpcode != .TextFrame && receivedOpcode != .Pong) {
let errCode = CloseCode.ProtocolError.rawValue let errCode = CloseCode.ProtocolError.rawValue
doDisconnect(errorWithDetail("unknown opcode: \(receivedOpcode)", code: errCode)) doDisconnect(errorWithDetail("unknown opcode: \(receivedOpcode)", code: errCode))
writeError(errCode) writeError(errCode)
@ -480,7 +469,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
writeError(errCode) writeError(errCode)
return return
} }
if receivedOpcode == OpCode.ConnectionClose.rawValue { if receivedOpcode == .ConnectionClose {
var code = CloseCode.Normal.rawValue var code = CloseCode.Normal.rawValue
if payloadLen == 1 { if payloadLen == 1 {
code = CloseCode.ProtocolError.rawValue code = CloseCode.ProtocolError.rawValue
@ -535,14 +524,12 @@ public class WebSocket : NSObject, NSStreamDelegate {
} else { } else {
data = NSData(bytes: UnsafePointer<UInt8>((buffer+offset)), length: Int(len)) data = NSData(bytes: UnsafePointer<UInt8>((buffer+offset)), length: Int(len))
} }
if receivedOpcode == OpCode.Pong.rawValue { if receivedOpcode == .Pong {
dispatch_async(queue,{ [weak self] in dispatch_async(queue) { [weak self] in
guard let s = self else { return } guard let s = self else { return }
if let pongBlock = s.onPong { s.onPong?()
pongBlock()
}
s.pongDelegate?.websocketDidReceivePong(s) s.pongDelegate?.websocketDidReceivePong(s)
}) }
let step = Int(offset+numericCast(len)) let step = Int(offset+numericCast(len))
let extra = bufferLen-step let extra = bufferLen-step
@ -555,15 +542,15 @@ public class WebSocket : NSObject, NSStreamDelegate {
if isControlFrame { if isControlFrame {
response = nil //don't append pings response = nil //don't append pings
} }
if isFin == 0 && receivedOpcode == OpCode.ContinueFrame.rawValue && response == nil { if isFin == 0 && receivedOpcode == .ContinueFrame && response == nil {
let errCode = CloseCode.ProtocolError.rawValue let errCode = CloseCode.ProtocolError.rawValue
doDisconnect(errorWithDetail("continue frame before a binary or text frame", code: errCode)) doDisconnect(errorWithDetail("continue frame before a binary or text frame", code: errCode))
writeError(errCode) writeError(errCode)
return return
} }
var isNew = false var isNew = false
if(response == nil) { if response == nil {
if receivedOpcode == OpCode.ContinueFrame.rawValue { if receivedOpcode == .ContinueFrame {
let errCode = CloseCode.ProtocolError.rawValue let errCode = CloseCode.ProtocolError.rawValue
doDisconnect(errorWithDetail("first frame can't be a continue frame", doDisconnect(errorWithDetail("first frame can't be a continue frame",
code: errCode)) code: errCode))
@ -572,11 +559,11 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
isNew = true isNew = true
response = WSResponse() response = WSResponse()
response!.code = OpCode(rawValue: receivedOpcode)! response!.code = receivedOpcode!
response!.bytesLeft = Int(dataLength) response!.bytesLeft = Int(dataLength)
response!.buffer = NSMutableData(data: data) response!.buffer = NSMutableData(data: data)
} else { } else {
if receivedOpcode == OpCode.ContinueFrame.rawValue { if receivedOpcode == .ContinueFrame {
response!.bytesLeft = Int(dataLength) response!.bytesLeft = Int(dataLength)
} else { } else {
let errCode = CloseCode.ProtocolError.rawValue let errCode = CloseCode.ProtocolError.rawValue
@ -587,19 +574,19 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
response!.buffer!.appendData(data) response!.buffer!.appendData(data)
} }
if response != nil { if let response = response {
response!.bytesLeft -= Int(len) response.bytesLeft -= Int(len)
response!.frameCount++ response.frameCount++
response!.isFin = isFin > 0 ? true : false response.isFin = isFin > 0 ? true : false
if(isNew) { if isNew {
readStack.append(response!) readStack.append(response)
} }
processResponse(response!) processResponse(response)
} }
let step = Int(offset+numericCast(len)) let step = Int(offset+numericCast(len))
let extra = bufferLen-step let extra = bufferLen-step
if(extra > 0) { if extra > 0 {
processExtra((buffer+step), bufferLen: extra) processExtra((buffer+step), bufferLen: extra)
} }
} }
@ -628,22 +615,18 @@ public class WebSocket : NSObject, NSStreamDelegate {
return false return false
} }
dispatch_async(queue,{ [weak self] in dispatch_async(queue) { [weak self] in
guard let s = self else { return } guard let s = self else { return }
if let textBlock = s.onText { s.onText?(str! as String)
textBlock(str! as String)
}
s.delegate?.websocketDidReceiveMessage(s, text: str! as String) s.delegate?.websocketDidReceiveMessage(s, text: str! as String)
}) }
} else if response.code == .BinaryFrame { } else if response.code == .BinaryFrame {
let data = response.buffer! //local copy so it is perverse for writing let data = response.buffer! //local copy so it is perverse for writing
dispatch_async(queue,{ [weak self] in dispatch_async(queue) { [weak self] in
guard let s = self else { return } guard let s = self else { return }
if let dataBlock = s.onData { s.onData?(data)
dataBlock(data)
}
s.delegate?.websocketDidReceiveData(s, data: data) s.delegate?.websocketDidReceiveData(s, data: data)
}) }
} }
readStack.removeLast() readStack.removeLast()
return true return true
@ -653,7 +636,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
///Create an error ///Create an error
private func errorWithDetail(detail: String, code: UInt16) -> NSError { private func errorWithDetail(detail: String, code: UInt16) -> NSError {
var details = Dictionary<String,String>() var details = [String: String]()
details[NSLocalizedDescriptionKey] = detail details[NSLocalizedDescriptionKey] = detail
return NSError(domain: WebSocket.ErrorDomain, code: Int(code), userInfo: details) return NSError(domain: WebSocket.ErrorDomain, code: Int(code), userInfo: details)
} }
@ -667,9 +650,8 @@ public class WebSocket : NSObject, NSStreamDelegate {
} }
///used to write things to the stream ///used to write things to the stream
private func dequeueWrite(data: NSData, code: OpCode) { private func dequeueWrite(data: NSData, code: OpCode) {
if !isConnected { guard isConnected else { return }
return
}
writeQueue.addOperationWithBlock { [weak self] in writeQueue.addOperationWithBlock { [weak self] in
//stream isn't ready, let's wait //stream isn't ready, let's wait
guard let s = self else { return } guard let s = self else { return }
@ -697,7 +679,7 @@ public class WebSocket : NSObject, NSStreamDelegate {
SecRandomCopyBytes(kSecRandomDefault, Int(sizeof(UInt32)), maskKey) SecRandomCopyBytes(kSecRandomDefault, Int(sizeof(UInt32)), maskKey)
offset += sizeof(UInt32) offset += sizeof(UInt32)
for (var i = 0; i < dataLength; i++) { for i in 0..<dataLength {
buffer[offset] = bytes[i] ^ maskKey[i % sizeof(UInt32)] buffer[offset] = bytes[i] ^ maskKey[i % sizeof(UInt32)]
offset += 1 offset += 1
} }
@ -732,20 +714,17 @@ public class WebSocket : NSObject, NSStreamDelegate {
///used to preform the disconnect delegate ///used to preform the disconnect delegate
private func doDisconnect(error: NSError?) { private func doDisconnect(error: NSError?) {
if !didDisconnect { guard !didDisconnect else { return }
dispatch_async(queue,{ [weak self] in
dispatch_async(queue) { [weak self] in
guard let s = self else { return } guard let s = self else { return }
s.didDisconnect = true s.didDisconnect = true
if let disconnect = s.onDisconnect { s.onDisconnect?(error)
disconnect(error)
}
s.delegate?.websocketDidDisconnect(s, error: error) s.delegate?.websocketDidDisconnect(s, error: error)
})
} }
} }
} }
private class SSLCert { private class SSLCert {
var certData: NSData? var certData: NSData?
var key: SecKeyRef? var key: SecKeyRef?
@ -790,13 +769,15 @@ private class SSLSecurity {
*/ */
convenience init(usePublicKeys: Bool = false) { convenience init(usePublicKeys: Bool = false) {
let paths = NSBundle.mainBundle().pathsForResourcesOfType("cer", inDirectory: ".") let paths = NSBundle.mainBundle().pathsForResourcesOfType("cer", inDirectory: ".")
var collect = Array<SSLCert>()
for path in paths { let certs = paths.reduce([SSLCert]()) { (var certs: [SSLCert], path: String) -> [SSLCert] in
if let d = NSData(contentsOfFile: path as String) { if let data = NSData(contentsOfFile: path) {
collect.append(SSLCert(data: d)) certs.append(SSLCert(data: data))
} }
return certs
} }
self.init(certs:collect, usePublicKeys: usePublicKeys)
self.init(certs: certs, usePublicKeys: usePublicKeys)
} }
/** /**
@ -811,27 +792,28 @@ private class SSLSecurity {
self.usePublicKeys = usePublicKeys self.usePublicKeys = usePublicKeys
if self.usePublicKeys { if self.usePublicKeys {
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0), { dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT,0)) {
var collect = Array<SecKeyRef>() let pubKeys = certs.reduce([SecKeyRef]()) { (var pubKeys: [SecKeyRef], cert: SSLCert) -> [SecKeyRef] in
for cert in certs {
if let data = cert.certData where cert.key == nil { if let data = cert.certData where cert.key == nil {
cert.key = self.extractPublicKey(data) cert.key = self.extractPublicKey(data)
} }
if let k = cert.key { if let key = cert.key {
collect.append(k) pubKeys.append(key)
} }
return pubKeys
} }
self.pubKeys = collect
self.pubKeys = pubKeys
self.isReady = true self.isReady = true
}) }
} else { } else {
var collect = Array<NSData>() let certificates = certs.reduce([NSData]()) { (var certificates: [NSData], cert: SSLCert) -> [NSData] in
for cert in certs { if let data = cert.certData {
if let d = cert.certData { certificates.append(data)
collect.append(d)
} }
return certificates
} }
self.certificates = collect self.certificates = certificates
self.isReady = true self.isReady = true
} }
} }
@ -863,23 +845,18 @@ private class SSLSecurity {
SecTrustSetPolicies(trust,policy) SecTrustSetPolicies(trust,policy)
if self.usePublicKeys { if self.usePublicKeys {
if let keys = self.pubKeys { if let keys = self.pubKeys {
var trustedCount = 0
let serverPubKeys = publicKeyChainForTrust(trust) let serverPubKeys = publicKeyChainForTrust(trust)
for serverKey in serverPubKeys as [AnyObject] { for serverKey in serverPubKeys as [AnyObject] {
for key in keys as [AnyObject] { for key in keys as [AnyObject] {
if serverKey.isEqual(key) { if serverKey.isEqual(key) {
trustedCount++
break
}
}
}
if trustedCount == serverPubKeys.count {
return true return true
} }
} }
}
}
} else if let certs = self.certificates { } else if let certs = self.certificates {
let serverCerts = certificateChainForTrust(trust) let serverCerts = certificateChainForTrust(trust)
var collect = Array<SecCertificate>() var collect = [SecCertificate]()
for cert in certs { for cert in certs {
collect.append(SecCertificateCreateWithData(nil,cert)!) collect.append(SecCertificateCreateWithData(nil,cert)!)
} }
@ -913,12 +890,10 @@ private class SSLSecurity {
- returns: a public key - returns: a public key
*/ */
func extractPublicKey(data: NSData) -> SecKeyRef? { func extractPublicKey(data: NSData) -> SecKeyRef? {
let possibleCert = SecCertificateCreateWithData(nil,data) guard let cert = SecCertificateCreateWithData(nil, data) else { return nil }
if let cert = possibleCert {
return extractPublicKeyFromCert(cert, policy: SecPolicyCreateBasicX509()) return extractPublicKeyFromCert(cert, policy: SecPolicyCreateBasicX509())
} }
return nil
}
/** /**
Get the public key from a certificate Get the public key from a certificate
@ -930,13 +905,13 @@ private class SSLSecurity {
func extractPublicKeyFromCert(cert: SecCertificate, policy: SecPolicy) -> SecKeyRef? { func extractPublicKeyFromCert(cert: SecCertificate, policy: SecPolicy) -> SecKeyRef? {
var possibleTrust: SecTrust? var possibleTrust: SecTrust?
SecTrustCreateWithCertificates(cert, policy, &possibleTrust) SecTrustCreateWithCertificates(cert, policy, &possibleTrust)
if let trust = possibleTrust {
guard let trust = possibleTrust else { return nil }
var result: SecTrustResultType = 0 var result: SecTrustResultType = 0
SecTrustEvaluate(trust, &result) SecTrustEvaluate(trust, &result)
return SecTrustCopyPublicKey(trust) return SecTrustCopyPublicKey(trust)
} }
return nil
}
/** /**
Get the certificate chain for the trust Get the certificate chain for the trust
@ -945,13 +920,14 @@ private class SSLSecurity {
- returns: the certificate chain for the trust - returns: the certificate chain for the trust
*/ */
func certificateChainForTrust(trust: SecTrustRef) -> Array<NSData> { func certificateChainForTrust(trust: SecTrustRef) -> [NSData] {
var collect = Array<NSData>() let certificates = (0..<SecTrustGetCertificateCount(trust)).reduce([NSData]()) { (var certificates: [NSData], index: Int) -> [NSData] in
for var i = 0; i < SecTrustGetCertificateCount(trust); i++ { let cert = SecTrustGetCertificateAtIndex(trust, index)
let cert = SecTrustGetCertificateAtIndex(trust,i) certificates.append(SecCertificateCopyData(cert!))
collect.append(SecCertificateCopyData(cert!)) return certificates
} }
return collect
return certificates
} }
/** /**
@ -961,16 +937,18 @@ private class SSLSecurity {
- returns: the public keys from the certifcate chain for the trust - returns: the public keys from the certifcate chain for the trust
*/ */
func publicKeyChainForTrust(trust: SecTrustRef) -> Array<SecKeyRef> { func publicKeyChainForTrust(trust: SecTrustRef) -> [SecKeyRef] {
var collect = Array<SecKeyRef>()
let policy = SecPolicyCreateBasicX509() let policy = SecPolicyCreateBasicX509()
for var i = 0; i < SecTrustGetCertificateCount(trust); i++ { let keys = (0..<SecTrustGetCertificateCount(trust)).reduce([SecKeyRef]()) { (var keys: [SecKeyRef], index: Int) -> [SecKeyRef] in
let cert = SecTrustGetCertificateAtIndex(trust,i) let cert = SecTrustGetCertificateAtIndex(trust, index)
if let key = extractPublicKeyFromCert(cert!, policy: policy) { if let key = extractPublicKeyFromCert(cert!, policy: policy) {
collect.append(key) keys.append(key)
} }
return keys
} }
return collect
return keys
} }