diff --git a/include/controller.h b/include/controller.h new file mode 100644 index 0000000..b5ef989 --- /dev/null +++ b/include/controller.h @@ -0,0 +1,42 @@ +#pragma once + +#include "server.h" +#include "servo.h" +#include "message.h" +#include + +class SesameController: public ServerConnectionCallbacks { + +public: + SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local); + +private: + + ServerConnection* server; + ServoController* servo; + AsyncWebServer* local; + + // The buffer to hold a received message while it is read + uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; + + // The buffer to hold a response while it is sent + uint8_t responseBuffer[AUTHENTICATED_MESSAGE_SIZE+1]; + SesameEvent* responseStatus; + AuthenticatedMessage* responseMessage; + uint16_t responseSize = 0; + + void handleLocalMessage(AsyncWebServerRequest *request); + // Based on https://stackoverflow.com/a/23898449/266720 + bool convertHexMessageToBinary(const char* str); + + void handleServerMessage(uint8_t* payload, size_t length); + void sendServerError(SesameEvent event); + + void processMessage(AuthenticatedMessage* message); + SesameEvent verifyAndProcessReceivedMessage(AuthenticatedMessage* message); + + uint16_t prepareResponseBuffer(SesameEvent event, uint8_t deviceId = 0); + void sendPreparedLocalResponse(AsyncWebServerRequest *request); + void sendLocalResponse(AsyncWebServerRequest *request, SesameEvent event, uint8_t deviceId = 0); + void sendPreparedServerResponse(); +}; \ No newline at end of file diff --git a/include/message.h b/include/message.h index beb6ae8..58ef742 100644 --- a/include/message.h +++ b/include/message.h @@ -1,6 +1,7 @@ #pragma once -#include "stdint.h" +#include +#include /** * @brief The size of a message authentication code @@ -87,9 +88,9 @@ typedef struct { } AuthenticatedMessage; #pragma pack(pop) -#define MESSAGE_CONTENT_SIZE sizeof(Message) +constexpr int MESSAGE_CONTENT_SIZE = sizeof(Message); -#define AUTHENTICATED_MESSAGE_SIZE sizeof(AuthenticatedMessage) +constexpr int AUTHENTICATED_MESSAGE_SIZE = sizeof(AuthenticatedMessage); /** * An event signaled from the device @@ -97,19 +98,27 @@ typedef struct { enum class SesameEvent { TextReceived = 1, UnexpectedSocketEvent = 2, - InvalidMessageData = 3, + InvalidMessageSize = 3, MessageAuthenticationFailed = 4, MessageTimeMismatch = 5, MessageCounterInvalid = 6, MessageAccepted = 7, MessageDeviceInvalid = 8, + InvalidUrlParameter = 9, + InvalidResponseAuthentication = 10, + DeviceSetupIncomplete = 11, }; /** * @brief A callback for messages received over the socket * - * The first parameter is the received message. - * The second parameter is the response to the remote. - * The return value is the type of event to respond with. + * The first parameter is a pointer to the byte buffer. + * The second parameter indicates the number of received bytes. */ -typedef SesameEvent (*MessageCallback)(AuthenticatedMessage*, AuthenticatedMessage*); \ No newline at end of file +typedef void (*MessageCallback)(uint8_t* payload, size_t length); + +/** + * @brief A callback for socket errors + */ +typedef void (*ErrorCallback)(SesameEvent event); + diff --git a/include/server.h b/include/server.h index e24609d..9e3699c 100644 --- a/include/server.h +++ b/include/server.h @@ -6,6 +6,14 @@ #include #include +class ServerConnectionCallbacks { + +public: + + virtual void sendServerError(SesameEvent event) = 0; + virtual void handleServerMessage(uint8_t* payload, size_t length) = 0; +}; + class ServerConnection { public: @@ -14,14 +22,26 @@ public: void connect(const char* key, uint32_t reconnectTime = 5000); - void connectSSL(const char* key, uint32_t reconnectTime = 5000); - void loop(); - void onMessage(MessageCallback callback); + /** + * @brief Set the handler + * + * @param callback The handler to handle messages and errors + */ + void setCallbackHandler(ServerConnectionCallbacks* callbacks); - // Indicator that the socket is connected. - bool socketIsConnected = false; + /** + * @brief Send a response message over the socket + * + * @param buffer The data buffer + * @param length The number of bytes to send + */ + void sendResponse(uint8_t* buffer, uint16_t length); + + bool isSocketConnected() { + return socketIsConnected; + } private: @@ -33,15 +53,14 @@ private: const char* key = NULL; - MessageCallback messageCallback = NULL; + // Indicator that the socket is connected. + bool socketIsConnected = false; + + ServerConnectionCallbacks* controller = NULL; // WebSocket to connect to the control server WebSocketsClient webSocket; - void reconnectAfter(uint32_t reconnectTime); - - void registerEventCallback(); - /** * Callback for WebSocket events. * @@ -51,35 +70,4 @@ private: * @param length The number of bytes received */ void webSocketEventHandler(WStype_t type, uint8_t * payload, size_t length); - - /** - * Process received binary data. - * - * Checks whether the received data is a valid and unused key, - * and then signals that the motor should move. - * Sends the event id to the server as a response to the request. - * - * If the key is valid, then `shouldStartOpening` is set to true. - * - * @param payload The pointer to the received data. - * @param length The number of bytes received. - */ - void processReceivedBytes(uint8_t* payload, size_t length); - - /** - * Send a response event to the server and include the next key index. - * - * Sends the event type as three byte. - * @param event The event type - */ - void sendFailureResponse(SesameEvent event); - - /** - * Send a response event to the server and include the next key index. - * - * Sends the event type as three byte. - * @param event The event type - */ - void sendResponse(SesameEvent event, AuthenticatedMessage* message); - }; \ No newline at end of file diff --git a/src/controller.cpp b/src/controller.cpp new file mode 100644 index 0000000..36f3685 --- /dev/null +++ b/src/controller.cpp @@ -0,0 +1,158 @@ +#include "controller.h" +#include "config.h" +#include "fresh.h" +#include "crypto.h" + +SesameController::SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local) : +server(server), servo(servo), local(local) { + // Set up response buffer + responseStatus = (SesameEvent*) responseBuffer; + responseMessage = (AuthenticatedMessage*) (responseBuffer + 1); + + // Direct messages and errors over the websocket to the controller + server->setCallbackHandler(this); + + // Direct messages from the local web server to the controller + local->on("/message", HTTP_POST, [this] (AsyncWebServerRequest *request) { + this->handleLocalMessage(request); + this->sendPreparedLocalResponse(request); + }); +} + +// MARK: Local + +void +SesameController::handleLocalMessage(AsyncWebServerRequest *request) { + if (!request->hasParam(messageUrlParameter)) { + Serial.println("Missing url parameter"); + prepareResponseBuffer(SesameEvent::InvalidUrlParameter); + return; + } + String encoded = request->getParam(messageUrlParameter)->value(); + if (!convertHexMessageToBinary(encoded.c_str())) { + Serial.println("Invalid hex encoding"); + prepareResponseBuffer(SesameEvent::InvalidMessageSize); + return; + } + processMessage((AuthenticatedMessage*) receivedMessageBuffer); +} + +void SesameController::sendPreparedLocalResponse(AsyncWebServerRequest *request) { + request->send_P(200, "application/octet-stream", responseBuffer, responseSize); + Serial.printf("[INFO] Local response %u (%u bytes)\n", responseBuffer[0], responseSize); +} + +// MARK: Server + +void SesameController::sendServerError(SesameEvent event) { + prepareResponseBuffer(event); + sendPreparedServerResponse(); +} + +void SesameController::handleServerMessage(uint8_t* payload, size_t length) { + if (length != AUTHENTICATED_MESSAGE_SIZE) { + prepareResponseBuffer(SesameEvent::InvalidMessageSize); + return; + } + + processMessage((AuthenticatedMessage*) payload); + sendPreparedServerResponse(); +} + +void SesameController::sendPreparedServerResponse() { + server->sendResponse(responseBuffer, responseSize); + Serial.printf("[INFO] Server response %u (%u bytes)\n", responseBuffer[0], responseSize); +} + +// MARK: Message handling + +void SesameController::processMessage(AuthenticatedMessage* message) { + SesameEvent event = verifyAndProcessReceivedMessage(message); + prepareResponseBuffer(event, message->message.device); +} + +/** + * Process a received message. + * + * Checks whether the received data is a valid, + * and then signals that the motor should move. + * + * @param message The message received from the remote + * @return The response to signal to the server. + */ +SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessage* message) { + if (!isAuthenticMessage(message, remoteKey, keySize)) { + return SesameEvent::MessageAuthenticationFailed; + } + if (!isDeviceIdValid(message->message.device)) { + return SesameEvent::MessageDeviceInvalid; + } + if (!isMessageCounterValid(message->message.id, message->message.device)) { + return SesameEvent::MessageCounterInvalid; + } + if (!isMessageTimeAcceptable(message->message.time)) { + return SesameEvent::MessageTimeMismatch; + } + + didUseMessageCounter(message->message.id, message->message.device); + // Move servo + servo->pressButton(); + Serial.printf("[Info] Accepted message %d\n", message->message.id); + return SesameEvent::MessageAccepted; +} + +bool allowMessageResponse(SesameEvent event) { + switch (event) { + case SesameEvent::MessageTimeMismatch: + case SesameEvent::MessageCounterInvalid: + case SesameEvent::MessageAccepted: + case SesameEvent::MessageDeviceInvalid: + return true; + default: + return false; + } +} + +uint16_t SesameController::prepareResponseBuffer(SesameEvent event, uint8_t deviceId) { + *responseStatus = event; + if (!allowMessageResponse(event)) { + return 1; + } + responseMessage->message.time = getEpochTime(); + responseMessage->message.id = getNextMessageCounter(deviceId); + responseMessage->message.device = deviceId; + if (!authenticateMessage(responseMessage, localKey, keySize)) { + *responseStatus = SesameEvent::InvalidResponseAuthentication; + return 1; + } + return 1 + AUTHENTICATED_MESSAGE_SIZE; +} + +// MARK: Helper + +// Based on https://stackoverflow.com/a/23898449/266720 +bool SesameController::convertHexMessageToBinary(const char* str) { + // TODO: Fail if invalid hex values are used + uint8_t idx0, idx1; + + // mapping of ASCII characters to hex values + const uint8_t hashmap[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // 01234567 + 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 89:;<=>? + 0x00, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, // @ABCDEFG + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // HIJKLMNO + }; + + size_t len = strlen(str); + if (len != AUTHENTICATED_MESSAGE_SIZE * 2) { + // Require exact message size + return false; + } + + for (size_t pos = 0; pos < len; pos += 2) { + idx0 = ((uint8_t)str[pos+0] & 0x1F) ^ 0x10; + idx1 = ((uint8_t)str[pos+1] & 0x1F) ^ 0x10; + receivedMessageBuffer[pos/2] = (uint8_t)(hashmap[idx0] << 4) | hashmap[idx1]; + }; + return true; +} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index dbdfeb0..f340f45 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -15,6 +15,7 @@ #include "server.h" #include "servo.h" #include "config.h" +#include "controller.h" /* Global variables */ @@ -24,22 +25,12 @@ ServoController servo(pwmTimer, servoFrequency, servoPin); AsyncWebServer local(localPort); -// The buffer to hold a received message while it is read -uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; - -/* Event callbacks */ - -SesameEvent handleReceivedMessage(AuthenticatedMessage* payload, AuthenticatedMessage* response); +SesameController controller(&server, &servo, &local); // Forward declare monitoring functions void ensureWiFiConnection(uint32_t time); void ensureWebSocketConnection(uint32_t time); -void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event); -void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message); -void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size); -void hexToBin(const char * str, uint8_t * bytes, size_t blen); - /* Logic */ void setup() { @@ -53,23 +44,6 @@ void setup() { prepareMessageCounterUsage(); //resetMessageCounters(); printMessageCounters(); - - server.onMessage(handleReceivedMessage); - - local.on("/message", HTTP_POST, [] (AsyncWebServerRequest *request) { - if (!request->hasParam(messageUrlParameter)) { - Serial.println("Missing url parameter"); - sendFailureResponse(request, SesameEvent::InvalidMessageData); - return; - } - String encoded = request->getParam(messageUrlParameter)->value(); - hexToBin(encoded.c_str(), receivedMessageBuffer, AUTHENTICATED_MESSAGE_SIZE); - // Process received message - AuthenticatedMessage* message = (AuthenticatedMessage*) receivedMessageBuffer; - AuthenticatedMessage responseMessage; - SesameEvent event = handleReceivedMessage(message, &responseMessage); - sendMessageResponse(request, event, &responseMessage); - }); } void loop() { @@ -102,100 +76,9 @@ void ensureWebSocketConnection(uint32_t time) { Serial.println("IP address: "); Serial.println(WiFi.localIP()); Serial.println("[INFO] WiFi connected, opening socket"); - server.connectSSL(serverAccessKey); + server.connect(serverAccessKey); configureNTP(timeOffsetToGMT, timeOffsetDaylightSavings, ntpServerUrl); printLocalTime(); local.begin(); } -} - -SesameEvent processMessage(AuthenticatedMessage* message) { - if (!isDeviceIdValid(message->message.device)) { - return SesameEvent::MessageDeviceInvalid; - } - if (!isMessageCounterValid(message->message.id, message->message.device)) { - return SesameEvent::MessageCounterInvalid; - } - if (!isMessageTimeAcceptable(message->message.time)) { - return SesameEvent::MessageTimeMismatch; - } - if (!isAuthenticMessage(message, remoteKey, keySize)) { - return SesameEvent::MessageAuthenticationFailed; - } - return SesameEvent::MessageAccepted; -} - -/** - * Process received binary data. - * - * Checks whether the received data is a valid and unused key, - * and then signals that the motor should move. - * - * @param payload The pointer to the received data. - * @param length The number of bytes received. - * @return The event to signal to the server. - */ -SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMessage* response) { - SesameEvent event = processMessage(message); - - // Only open when message is valid - if (event == SesameEvent::MessageAccepted) { - didUseMessageCounter(message->message.id, message->message.device); - // Move servo - servo.pressButton(); - Serial.printf("[Info] Accepted message %d\n", message->message.id); - } - - // Create response for all cases - response->message.time = getEpochTime(); - response->message.id = getNextMessageCounter(message->message.device); - if (!authenticateMessage(response, localKey, keySize)) { - return SesameEvent::MessageAuthenticationFailed; - } - return event; -} - -void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) { - uint8_t response = static_cast(event); - sendResponse(request, &response, 1); -} - -void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message) { - uint8_t response[AUTHENTICATED_MESSAGE_SIZE+1]; - response[0] = static_cast(event); - memcpy(response+1, (uint8_t*) message, AUTHENTICATED_MESSAGE_SIZE); - sendResponse(request, response, AUTHENTICATED_MESSAGE_SIZE+1); -} - -void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size) { - request->send_P(200, "application/octet-stream", buffer, size); - Serial.printf("[INFO] Local response %d\n", buffer[0]); -} - -// Based on https://stackoverflow.com/a/23898449/266720 -void hexToBin(const char * str, uint8_t * bytes, size_t blen) { - uint8_t idx0, idx1; - - // mapping of ASCII characters to hex values - const uint8_t hashmap[] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // 01234567 - 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 89:;<=>? - 0x00, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, // @ABCDEFG - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // HIJKLMNO - }; - - memset(bytes, 0, blen); - - size_t len = strlen(str); - if (len % 2) { - // Require two chars per byte - return; - } - size_t end = min(blen*2, len); - - for (size_t pos = 0; pos < end; pos += 2) { - idx0 = ((uint8_t)str[pos+0] & 0x1F) ^ 0x10; - idx1 = ((uint8_t)str[pos+1] & 0x1F) ^ 0x10; - bytes[pos/2] = (uint8_t)(hashmap[idx0] << 4) | hashmap[idx1]; - }; } \ No newline at end of file diff --git a/src/server.cpp b/src/server.cpp index 6989fe5..db74aa1 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -10,38 +10,30 @@ ServerConnection::ServerConnection(const char* url, int port, const char* path) } void ServerConnection::connect(const char* key, uint32_t reconnectTime) { - webSocket.begin(url, port, path); - registerEventCallback(); - reconnectAfter(reconnectTime); -} - -void ServerConnection::connectSSL(const char* key, uint32_t reconnectTime) { if (socketIsConnected) { return; } + if (controller == NULL) { + Serial.println("[ERROR] No callbacks set for server"); + return; + } this->key = key; + webSocket.beginSSL(url, port, path); - registerEventCallback(); - reconnectAfter(reconnectTime); + + std::function f = [this](WStype_t type, uint8_t *payload, size_t length) { + this->webSocketEventHandler(type, payload, length); + }; + webSocket.onEvent(f); + webSocket.setReconnectInterval(reconnectTime); } void ServerConnection::loop() { webSocket.loop(); } -void ServerConnection::onMessage(MessageCallback callback) { - messageCallback = callback; -} - -void ServerConnection::reconnectAfter(uint32_t reconnectTime) { - webSocket.setReconnectInterval(reconnectTime); -} - -void ServerConnection::registerEventCallback() { - std::function f = [this](WStype_t type, uint8_t *payload, size_t length) { - this->webSocketEventHandler(type, payload, length); - }; - webSocket.onEvent(f); +void ServerConnection::setCallbackHandler(ServerConnectionCallbacks* callbacks) { + controller = callbacks; } void ServerConnection::webSocketEventHandler(WStype_t type, uint8_t * payload, size_t length) { @@ -57,10 +49,10 @@ switch(type) { webSocket.enableHeartbeat(pingInterval, pongTimeout, disconnectTimeoutCount); break; case WStype_TEXT: - sendFailureResponse(SesameEvent::TextReceived); + controller->sendServerError(SesameEvent::TextReceived); break; case WStype_BIN: - processReceivedBytes(payload, length); + controller->handleServerMessage(payload, length); break; case WStype_PONG: break; @@ -70,36 +62,11 @@ switch(type) { case WStype_FRAGMENT_BIN_START: case WStype_FRAGMENT: case WStype_FRAGMENT_FIN: - sendFailureResponse(SesameEvent::UnexpectedSocketEvent); + controller->sendServerError(SesameEvent::UnexpectedSocketEvent); break; } } -void ServerConnection::processReceivedBytes(uint8_t* payload, size_t length) { - if (length != AUTHENTICATED_MESSAGE_SIZE) { - sendFailureResponse(SesameEvent::InvalidMessageData); - return; - } - AuthenticatedMessage* message = (AuthenticatedMessage*) payload; - if (messageCallback == NULL) { - sendFailureResponse(SesameEvent::MessageAuthenticationFailed); - return; - } - AuthenticatedMessage responseMessage; - SesameEvent event = messageCallback(message, &responseMessage); - sendResponse(event, &responseMessage); -} - -void ServerConnection::sendFailureResponse(SesameEvent event) { - uint8_t response = static_cast(event); - webSocket.sendBIN(&response, 1); - Serial.printf("[INFO] Socket failure %d\n", response); -} - -void ServerConnection::sendResponse(SesameEvent event, AuthenticatedMessage* message) { - uint8_t response[AUTHENTICATED_MESSAGE_SIZE+1]; - response[0] = static_cast(event); - memcpy(response+1, (uint8_t*) message, AUTHENTICATED_MESSAGE_SIZE); - webSocket.sendBIN(response, AUTHENTICATED_MESSAGE_SIZE+1); - Serial.printf("[INFO] Socket response %d\n", response[0]); +void ServerConnection::sendResponse(uint8_t* buffer, uint16_t length) { + webSocket.sendBIN(buffer, length); } \ No newline at end of file