diff --git a/Sources/App/API/MessageResult.swift b/Sources/App/API/MessageResult.swift index 2c34a96..b58fef0 100644 --- a/Sources/App/API/MessageResult.swift +++ b/Sources/App/API/MessageResult.swift @@ -86,7 +86,7 @@ enum MessageResult: UInt8 { /// A valid server challenge was received case deviceAvailable = 37 - case invalidSignatureByDevice = 38 + case invalidSignatureFromDevice = 38 case invalidMessageTypeFromDevice = 39 @@ -162,7 +162,7 @@ extension MessageResult: CustomStringConvertible { return "Unexpected server response code" case .deviceAvailable: return "Device available" - case .invalidSignatureByDevice: + case .invalidSignatureFromDevice: return "Invalid device signature" case .invalidMessageTypeFromDevice: return "Message type from device invalid" diff --git a/Sources/App/DeviceManager.swift b/Sources/App/DeviceManager.swift index ea8095b..6aab23a 100644 --- a/Sources/App/DeviceManager.swift +++ b/Sources/App/DeviceManager.swift @@ -4,20 +4,20 @@ import Vapor import Clairvoyant final class DeviceManager { - + /// The connection to the device private var connection: WebSocket? - + /// The authentication token of the device for the socket connection private let deviceKey: Data - + /// The authentication token of the remote private let remoteKey: Data - + private let deviceTimeout: Int64 - + private let deviceConnectedMetric: Metric - + private let messagesToDeviceMetric: Metric let serverStatus: Metric @@ -31,6 +31,7 @@ final class DeviceManager { /// A promise to finish the request once the device responds or times out private var requestInProgress: CheckedContinuation? + private var receivedMessageData: Data? init(deviceKey: Data, remoteKey: Data, deviceTimeout: Int64, serverStatus: Metric) { self.deviceKey = deviceKey @@ -51,11 +52,11 @@ final class DeviceManager { _ = try? await serverStatus.update(deviceIsConnected ? .nominal : .reducedFunctionality) await updateDeviceConnectionMetric() } - + private func updateDeviceConnectionMetric() async { _ = try? await deviceConnectedMetric.update(deviceIsConnected) } - + private func updateMessageCountMetric() async { let lastValue = await messagesToDeviceMetric.lastValue()?.value ?? 0 _ = try? await messagesToDeviceMetric.update(lastValue + 1) @@ -74,9 +75,11 @@ final class DeviceManager { connection = nil throw MessageResult.deviceNotConnected } - guard requestInProgress == nil else { + guard receivedMessageData == nil else { throw MessageResult.tooManyRequests } + // Indicate that a message is in transit + receivedMessageData = Data() do { try await socket.send(Array(message)) } catch { @@ -84,6 +87,12 @@ final class DeviceManager { } startTimeoutForDeviceRequest(on: eventLoop) + // Check if a full message has already been received + if let receivedMessageData, receivedMessageData.count == SignedMessage.size { + self.receivedMessageData = nil + return receivedMessageData + } + // Wait until a fill message is received, or a timeout occurs let result: Data = try await withCheckedThrowingContinuation { continuation in self.requestInProgress = continuation } @@ -98,13 +107,44 @@ final class DeviceManager { } private func resumeDeviceRequest(with data: Data) { - requestInProgress?.resume(returning: data) - requestInProgress = nil + guard let receivedMessageData else { + print("[WARN] Received \(data.count) bytes after message completion") + self.requestInProgress = nil + return + } + let newData = receivedMessageData + data + if newData.count < SignedMessage.size { + // Wait for more data + self.receivedMessageData = newData + return + } + self.receivedMessageData = nil + guard let requestInProgress else { + print("[WARN] Received \(newData.count) bytes, but no continuation to resume") + return + } + self.requestInProgress = nil + guard newData.count == SignedMessage.size else { + print("[WARN] Received \(newData.count) bytes, expected \(SignedMessage.size) for a message.") + requestInProgress.resume(throwing: MessageResult.invalidMessageSizeFromDevice) + return + } + requestInProgress.resume(returning: newData) } private func resumeDeviceRequest(with result: MessageResult) { - requestInProgress?.resume(throwing: result) - requestInProgress = nil + guard let receivedMessageData else { + print("[WARN] Result after message completed: \(result)") + self.requestInProgress = nil + return + } + self.receivedMessageData = nil + guard let requestInProgress else { + print("[WARN] Request in progress (\(receivedMessageData.count) bytes), but no continuation found for result: \(result)") + return + } + self.requestInProgress = nil + requestInProgress.resume(throwing: result) } func authenticateRemote(_ token: Data) -> Bool { @@ -114,7 +154,7 @@ final class DeviceManager { func processDeviceResponse(_ buffer: ByteBuffer) { guard let data = buffer.getData(at: 0, length: buffer.readableBytes) else { - log("Failed to get data buffer received from device") + print("Failed to get data buffer received from device") self.resumeDeviceRequest(with: .invalidMessageSizeFromDevice) return } @@ -140,7 +180,7 @@ final class DeviceManager { func createNewDeviceConnection(socket: WebSocket, auth: String) async { guard let key = Data(fromHexEncodedString: auth), SHA256.hash(data: key) == self.deviceKey else { - log("Invalid device key while opening socket") + log("[WARN] Invalid device key while opening socket") try? await socket.close() return } @@ -150,17 +190,32 @@ final class DeviceManager { socket.eventLoop.execute { socket.pingInterval = .seconds(10) - socket.onText { socket, text in + socket.onText { [weak self] socket, text in print("[WARN] Received text over socket: \(text)") // Close connection to prevent spamming the log try? await socket.close() + + guard let self else { + print("[WARN] No reference to self to handle text over socket") + return + } + self.didCloseDeviceSocket() } socket.onBinary { [weak self] _, data in - self?.processDeviceResponse(data) + guard let self else { + print("[WARN] No reference to self to process binary data on socket") + return + } + self.processDeviceResponse(data) } + socket.onClose.whenComplete { [weak self] _ in - self?.didCloseDeviceSocket() + guard let self else { + print("[WARN] No reference to self to handle socket closing") + return + } + self.didCloseDeviceSocket() } } log("[INFO] Socket connected")