Sesame-Server/Sources/App/DeviceManager.swift
2023-12-10 19:32:09 +01:00

234 lines
8.3 KiB
Swift

import Foundation
import WebSocketKit
import Vapor
import Clairvoyant
/**
 Manages the WebSocket connection to the device and forwards messages from the remote to it.

 At most one request is in flight at any time. A request is registered before its message is
 written to the socket and is finished exactly once: by the device response, by a failed send,
 or by the timeout.

 - Note: NOTE(review): The mutable state is touched from `async` callers as well as from socket
 and event loop callbacks without synchronization. Confirm that this is safe for the deployment.
 */
final class DeviceManager {

    /// The state of the single request which is waiting for a response from the device
    private struct PendingRequest {

        /// Identifies the request, so that late timeouts or send failures of older requests are ignored
        let id: UInt64

        /// The continuation of the suspended caller, which must be resumed exactly once
        let continuation: CheckedContinuation<Data, Error>

        /// The response bytes collected from the device so far
        var receivedData = Data()
    }

    /// The connection to the device
    private var connection: WebSocket?

    /// The SHA256 hash of the authentication token of the device for the socket connection
    private let deviceKey: Data

    /// The SHA256 hash of the authentication token of the remote
    private let remoteKey: Data

    /// The time (in seconds) to wait for a device response before failing a request
    private let deviceTimeout: Int64

    private let deviceConnectedMetric: Metric<Bool>

    private let messagesToDeviceMetric: Metric<Int>

    let serverStatus: Metric<ServerStatus>

    /// Indicates that a device socket is stored and not closed
    var deviceIsConnected: Bool {
        guard let connection, !connection.isClosed else {
            return false
        }
        return true
    }

    /// The request currently waiting for the device, if any
    private var pendingRequest: PendingRequest?

    /// The id given to the most recently registered request
    private var lastRequestID: UInt64 = 0

    var logger: Logger?

    private func printAndFlush(_ message: String) {
        logger?.notice(.init(stringLiteral: message))
    }

    /**
     Create a device manager.
     - Parameter deviceKey: The value compared against the SHA256 hash of the key presented by the device
     - Parameter remoteKey: The value compared against the SHA256 hash of the token presented by the remote
     - Parameter deviceTimeout: The number of seconds to wait for a device response
     - Parameter serverStatus: The metric updated together with the device connection state
     */
    init(deviceKey: Data, remoteKey: Data, deviceTimeout: Int64, serverStatus: Metric<ServerStatus>) {
        self.deviceKey = deviceKey
        self.remoteKey = remoteKey
        self.deviceTimeout = deviceTimeout
        self.deviceConnectedMetric = .init(
            "sesame.connected",
            name: "Device connection",
            description: "Shows if the device is connected via WebSocket")
        self.messagesToDeviceMetric = .init(
            "sesame.messages",
            name: "Forwarded Messages",
            description: "The number of messages transmitted to the device")
        self.serverStatus = serverStatus
    }

    /// Publish the current connection state to the server status and device connection metrics.
    /// Failures to update a metric are ignored.
    func updateDeviceConnectionMetrics() async {
        let isConnected = deviceIsConnected
        _ = try? await serverStatus.update(isConnected ? .nominal : .reducedFunctionality)
        _ = try? await deviceConnectedMetric.update(isConnected)
    }

    /// Increment the metric counting the forwarded messages. A failed update is ignored.
    private func updateMessageCountMetric() async {
        let lastValue = await messagesToDeviceMetric.lastValue()?.value ?? 0
        _ = try? await messagesToDeviceMetric.update(lastValue + 1)
    }

    // MARK: API

    /**
     Forward a message from the remote to the device and wait for the response.

     - Parameter message: The message to send, which must be `SignedMessage.size` bytes long
     - Parameter authToken: The token of the remote, checked with `authenticateRemote(_:)`
     - Parameter eventLoop: The event loop used to schedule the timeout
     - Returns: The response of the device (`SignedMessage.size` bytes)
     - Throws: `MessageResult.invalidMessageSizeFromRemote`, `.invalidServerAuthenticationFromRemote`,
     `.deviceNotConnected` (no open socket, or the message could not be sent), `.tooManyRequests`
     (another request is pending), `.deviceTimedOut`, or `.invalidMessageSizeFromDevice`
     */
    func sendMessageToDevice(_ message: Data, authToken: Data, on eventLoop: EventLoop) async throws -> Data {
        guard message.count == SignedMessage.size else {
            throw MessageResult.invalidMessageSizeFromRemote
        }
        guard authenticateRemote(authToken) else {
            throw MessageResult.invalidServerAuthenticationFromRemote
        }
        guard let socket = connection, !socket.isClosed else {
            connection = nil
            throw MessageResult.deviceNotConnected
        }
        guard pendingRequest == nil else {
            throw MessageResult.tooManyRequests
        }

        let bytes = Array(message)
        // Wait until a full message is received, the send fails, or a timeout occurs
        let response: Data = try await withCheckedThrowingContinuation { continuation in
            // Register the request before writing to the socket, so that neither a fast
            // response nor a failure can arrive while no continuation is installed
            let requestID = self.registerRequest(continuation)
            self.startTimeoutForDeviceRequest(requestID, on: eventLoop)

            let sent = socket.eventLoop.makePromise(of: Void.self)
            sent.futureResult.whenFailure { [weak self] _ in
                // Releases the request, so that later requests are not rejected
                self?.failDeviceRequest(requestID, with: .deviceNotConnected)
            }
            socket.send(bytes, promise: sent)
        }
        await updateMessageCountMetric()
        printAndFlush("[INFO] Message completed")
        return response
    }

    /// Store a new pending request for the continuation and return its id.
    private func registerRequest(_ continuation: CheckedContinuation<Data, Error>) -> UInt64 {
        lastRequestID &+= 1
        pendingRequest = PendingRequest(id: lastRequestID, continuation: continuation)
        return lastRequestID
    }

    /// Schedule the failure of the given request with `.deviceTimedOut` after `deviceTimeout` seconds.
    private func startTimeoutForDeviceRequest(_ requestID: UInt64, on eventLoop: EventLoop) {
        eventLoop.scheduleTask(in: .seconds(deviceTimeout)) { [weak self] in
            guard let self else {
                print("[WARN] No reference to self to handle device timeout")
                return
            }
            self.failDeviceRequest(requestID, with: .deviceTimedOut)
        }
    }

    /// Fail the pending request, but only if it is still the request with the given id.
    /// Timeouts and send failures of requests which already finished are ignored.
    private func failDeviceRequest(_ requestID: UInt64, with result: MessageResult) {
        guard let request = pendingRequest, request.id == requestID else {
            return
        }
        resumeDeviceRequest(with: result)
    }

    /// Add data from the device to the pending request, and finish the request
    /// once at least `SignedMessage.size` bytes are collected.
    private func resumeDeviceRequest(with data: Data) {
        guard var request = pendingRequest else {
            printAndFlush("[WARN] Received \(data.count) bytes after message completion")
            return
        }
        request.receivedData.append(data)
        let newData = request.receivedData
        if newData.count < SignedMessage.size {
            // Wait for more data
            pendingRequest = request
            return
        }
        // Clear the state before resuming, so that the continuation can't be resumed twice
        pendingRequest = nil
        guard newData.count == SignedMessage.size else {
            printAndFlush("[WARN] Received \(newData.count) bytes, expected \(SignedMessage.size) for a message.")
            request.continuation.resume(throwing: MessageResult.invalidMessageSizeFromDevice)
            return
        }
        request.continuation.resume(returning: newData)
    }

    /// Finish the pending request by throwing the given result to the waiting caller.
    private func resumeDeviceRequest(with result: MessageResult) {
        guard let request = pendingRequest else {
            printAndFlush("[WARN] Result after message completed: \(result)")
            return
        }
        // Clear the state before resuming, so that the continuation can't be resumed twice
        pendingRequest = nil
        request.continuation.resume(throwing: result)
    }

    /// Check if the SHA256 hash of the token matches the remote key.
    func authenticateRemote(_ token: Data) -> Bool {
        let hash = SHA256.hash(data: token)
        return hash == remoteKey
    }

    /// Handle binary data received from the device socket.
    /// Data which is not exactly `SignedMessage.size` bytes long fails the pending request.
    func processDeviceResponse(_ buffer: ByteBuffer) {
        guard let data = buffer.getData(at: 0, length: buffer.readableBytes) else {
            print("Failed to get data buffer received from device")
            self.resumeDeviceRequest(with: .invalidMessageSizeFromDevice)
            return
        }
        guard data.count == SignedMessage.size else {
            print("Invalid size of device message: \(data.count)")
            self.resumeDeviceRequest(with: .invalidMessageSizeFromDevice)
            return
        }
        self.resumeDeviceRequest(with: data)
    }

    /// Forget the stored device connection.
    func didCloseDeviceSocket() {
        connection = nil
    }

    /// Forget the device connection, but only if the closed socket is the one currently stored.
    /// A socket which was already replaced must not remove the connection of its successor.
    private func didClose(_ socket: WebSocket) {
        guard connection === socket else {
            return
        }
        didCloseDeviceSocket()
    }

    /// Close and forget the current device connection (if any), then update the connection metrics.
    func removeDeviceConnection() async {
        try? await connection?.close()
        connection = nil
        await updateDeviceConnectionMetrics()
    }

    /**
     Accept a new socket as the device connection, replacing any existing connection.

     The socket is closed and ignored if `auth` is not a hex string whose
     SHA256 hash matches the device key.
     - Parameter socket: The newly opened socket
     - Parameter auth: The hex encoded device key
     */
    func createNewDeviceConnection(socket: WebSocket, auth: String) async {
        guard let key = Data(fromHexEncodedString: auth),
              SHA256.hash(data: key) == self.deviceKey else {
            log("[WARN] Invalid device key while opening socket")
            try? await socket.close()
            return
        }
        await removeDeviceConnection()

        connection = socket
        socket.eventLoop.execute {
            socket.pingInterval = .seconds(10)
            socket.onText { [weak self] socket, text in
                self?.printAndFlush("[WARN] Received text over socket: \(text)")
                // Close connection to prevent spamming the log
                try? await socket.close()

                guard let self else {
                    print("[WARN] No reference to self to handle text over socket")
                    return
                }
                self.didClose(socket)
            }

            socket.onBinary { [weak self] _, data in
                guard let self else {
                    print("[WARN] No reference to self to process binary data on socket")
                    return
                }
                self.processDeviceResponse(data)
            }
            socket.onClose.whenComplete { [weak self] _ in
                guard let self else {
                    print("[WARN] No reference to self to handle socket closing")
                    return
                }
                self.didClose(socket)
            }
        }
        log("[INFO] Socket connected")
        await updateDeviceConnectionMetrics()
    }
}