From 67169240f925f08ac297d9439f2ac00acb1662a5 Mon Sep 17 00:00:00 2001 From: Christoph Hagen Date: Wed, 9 Aug 2023 13:13:38 +0200 Subject: [PATCH] Create storage class --- include/controller.h | 7 ++-- include/fresh.h | 67 +---------------------------------- include/message.h | 2 ++ include/storage.h | 83 ++++++++++++++++++++++++++++++++++++++++++++ src/controller.cpp | 21 +++++++---- src/fresh.cpp | 53 +--------------------------- src/main.cpp | 6 ++-- src/storage.cpp | 53 ++++++++++++++++++++++++++++ 8 files changed, 162 insertions(+), 130 deletions(-) create mode 100644 include/storage.h create mode 100644 src/storage.cpp diff --git a/include/controller.h b/include/controller.h index b5ef989..946aff8 100644 --- a/include/controller.h +++ b/include/controller.h @@ -3,18 +3,22 @@ #include "server.h" #include "servo.h" #include "message.h" +#include "storage.h" #include class SesameController: public ServerConnectionCallbacks { public: - SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local); + SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local, uint8_t remoteDeviceCount); + + void configure(); private: ServerConnection* server; ServoController* servo; AsyncWebServer* local; + Storage storage; // The buffer to hold a received message while it is read uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; @@ -37,6 +41,5 @@ private: uint16_t prepareResponseBuffer(SesameEvent event, uint8_t deviceId = 0); void sendPreparedLocalResponse(AsyncWebServerRequest *request); - void sendLocalResponse(AsyncWebServerRequest *request, SesameEvent event, uint8_t deviceId = 0); void sendPreparedServerResponse(); }; \ No newline at end of file diff --git a/include/fresh.h b/include/fresh.h index cd017cc..6eaa9da 100644 --- a/include/fresh.h +++ b/include/fresh.h @@ -3,11 +3,6 @@ #include #include "config.h" -/** - * @brief The size of the message counter in bytes (uint32_t) - */ -#define MESSAGE_COUNTER_SIZE sizeof(uint32_t) - /** * @brief Configure an NTP server to get the current time * @@ -54,64 +49,4 @@ void setMessageTimeAllowedOffset(uint32_t offset); * @return true The time is within the acceptable offset of the local time * @return false The message time is invalid */ -bool isMessageTimeAcceptable(uint32_t messageTime); - -/** - * @brief Initialize the use of the message counter API - * - * The message counter is stored in EEPROM, which must be initialized before use. - * - * @note The ESP32 does not have a true EEPROM, - * which is emulated using a section of the flash memory. - */ -void prepareMessageCounterUsage(); - -/** - * @brief Get the expected count for the next message. - * - * The counter is stored in EEPROM to persist across restarts - * - * @return The next counter to use by the remote - */ -uint32_t getNextMessageCounter(uint8_t deviceId); - -/** - * @brief Print info about the current message counter to the serial output - * - */ -void printMessageCounters(); - -bool isDeviceIdValid(uint8_t deviceId); - -/** - * @brief Check if a received counter is valid - * - * The counter is valid if it is larger than the previous counter - * (larger or equal to the next expected counter). - * - * @param counter The counter to check - * @return true The counter is valid - * @return false The counter belongs to an old message - */ -bool isMessageCounterValid(uint32_t counter, uint8_t deviceId); - -/** - * @brief Mark a counter of a message as used. - * - * The counter value is stored in EEPROM to persist across restarts. - * - * All messages with counters lower than the given one will become invalid. - * - * @param counter The counter used in the last message. - */ -void didUseMessageCounter(uint32_t counter, uint8_t deviceId); - -/** - * @brief Reset the message counter. - * - * @warning The counter should never be reset in production environments, - * and only together with a new secret key. Otherwise old messages may be - * used for replay attacks. - * - */ -void resetMessageCounters(); +bool isMessageTimeAcceptable(uint32_t messageTime); \ No newline at end of file diff --git a/include/message.h b/include/message.h index 58ef742..97b7872 100644 --- a/include/message.h +++ b/include/message.h @@ -53,6 +53,8 @@ typedef struct { } Message; +constexpr size_t messageCounterSize = sizeof(uint32_t); + /** * @brief An authenticated message by the mobile device to command unlocking. * diff --git a/include/storage.h b/include/storage.h new file mode 100644 index 0000000..530e4bd --- /dev/null +++ b/include/storage.h @@ -0,0 +1,83 @@ +#pragma once + +#include + +class Storage { + +public: + + Storage(uint8_t remoteDeviceCount) : remoteDeviceCount(remoteDeviceCount) { }; + + /** + * @brief Initialize the use of the message counter API + * + * The message counter is stored in EEPROM, which must be initialized before use. + * + * @note The ESP32 does not have a true EEPROM, + * which is emulated using a section of the flash memory. + */ + void configure(); + + /** + * @brief Check if a device ID is allowed + * + * @param deviceId The ID to check + * @return true The id is valid + * @return false The id is invalid + */ + bool isDeviceIdValid(uint8_t deviceId); + + /** + * @brief Check if a received counter is valid + * + * The counter is valid if it is larger than the previous counter + * (larger or equal to the next expected counter). + * + * @param counter The counter to check + * @return true The counter is valid + * @return false The counter belongs to an old message + */ + bool isMessageCounterValid(uint32_t counter, uint8_t deviceId); + + /** + * @brief Mark a counter of a message as used. + * + * The counter value is stored in EEPROM to persist across restarts. + * + * All messages with counters lower than the given one will become invalid. + * + * @param counter The counter used in the last message. + */ + void didUseMessageCounter(uint32_t counter, uint8_t deviceId); + + /** + * @brief Get the expected count for the next message. + * + * The counter is stored in EEPROM to persist across restarts + * + * @return The next counter to use by the remote + */ + uint32_t getNextMessageCounter(uint8_t deviceId); + + /** + * @brief Print info about the current message counter to the serial output + * + */ + void printMessageCounters(); + + /** + * @brief Reset the message counter. + * + * @warning The counter should never be reset in production environments, + * and only together with a new secret key. Otherwise old messages may be + * used for replay attacks. + * + */ + void resetMessageCounters(); + +private: + + uint8_t remoteDeviceCount; + + void setMessageCounter(uint32_t counter, uint8_t deviceId); +}; \ No newline at end of file diff --git a/src/controller.cpp b/src/controller.cpp index 36f3685..9697a0f 100644 --- a/src/controller.cpp +++ b/src/controller.cpp @@ -3,12 +3,18 @@ #include "fresh.h" #include "crypto.h" -SesameController::SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local) : -server(server), servo(servo), local(local) { +SesameController::SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local, uint8_t remoteDeviceCount) : + server(server), servo(servo), local(local), storage(remoteDeviceCount) { + // Set up response buffer responseStatus = (SesameEvent*) responseBuffer; responseMessage = (AuthenticatedMessage*) (responseBuffer + 1); +} +void SesameController::configure() { + // Prepare EEPROM for reading and writing + storage.configure(); + // Direct messages and errors over the websocket to the controller server->setCallbackHandler(this); @@ -17,6 +23,9 @@ server(server), servo(servo), local(local) { this->handleLocalMessage(request); this->sendPreparedLocalResponse(request); }); + + //storage.resetMessageCounters(); + storage.printMessageCounters(); } // MARK: Local @@ -84,17 +93,17 @@ SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessa if (!isAuthenticMessage(message, remoteKey, keySize)) { return SesameEvent::MessageAuthenticationFailed; } - if (!isDeviceIdValid(message->message.device)) { + if (!storage.isDeviceIdValid(message->message.device)) { return SesameEvent::MessageDeviceInvalid; } - if (!isMessageCounterValid(message->message.id, message->message.device)) { + if (!storage.isMessageCounterValid(message->message.id, message->message.device)) { return SesameEvent::MessageCounterInvalid; } if (!isMessageTimeAcceptable(message->message.time)) { return SesameEvent::MessageTimeMismatch; } - didUseMessageCounter(message->message.id, message->message.device); + storage.didUseMessageCounter(message->message.id, message->message.device); // Move servo servo->pressButton(); Serial.printf("[Info] Accepted message %d\n", message->message.id); @@ -119,7 +128,7 @@ uint16_t SesameController::prepareResponseBuffer(SesameEvent event, uint8_t devi return 1; } responseMessage->message.time = getEpochTime(); - responseMessage->message.id = getNextMessageCounter(deviceId); + responseMessage->message.id = storage.getNextMessageCounter(deviceId); responseMessage->message.device = deviceId; if (!authenticateMessage(responseMessage, localKey, keySize)) { *responseStatus = SesameEvent::InvalidResponseAuthentication; diff --git a/src/fresh.cpp b/src/fresh.cpp index a891a9b..8b770f2 100644 --- a/src/fresh.cpp +++ b/src/fresh.cpp @@ -2,7 +2,6 @@ #include #include -#include /** * @brief The allowed discrepancy between the time of a received message @@ -55,54 +54,4 @@ bool isMessageTimeAcceptable(uint32_t t) { return false; } return true; -} - -void prepareMessageCounterUsage() { - EEPROM.begin(MESSAGE_COUNTER_SIZE * remoteDeviceCount); -} - -uint32_t getNextMessageCounter(uint8_t deviceId) { - int offset = deviceId * MESSAGE_COUNTER_SIZE; - uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24; - counter += (uint32_t) EEPROM.read(offset + 1) << 16; - counter += (uint32_t) EEPROM.read(offset + 2) << 8; - counter += (uint32_t) EEPROM.read(offset + 3); - return counter; -} - -void printMessageCounters() { - Serial.print("[INFO] Next message numbers:"); - for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { - Serial.printf(" %u", getNextMessageCounter(i)); - } - Serial.println(""); -} - -bool isDeviceIdValid(uint8_t deviceId) { - return deviceId < remoteDeviceCount; -} - -bool isMessageCounterValid(uint32_t counter, uint8_t deviceId) { - return counter >= getNextMessageCounter(deviceId); -} - -void setMessageCounter(uint32_t counter, uint8_t deviceId) { - int offset = deviceId * MESSAGE_COUNTER_SIZE; - EEPROM.write(offset + 0, (counter >> 24) & 0xFF); - EEPROM.write(offset + 1, (counter >> 16) & 0xFF); - EEPROM.write(offset + 2, (counter >> 8) & 0xFF); - EEPROM.write(offset + 3, counter & 0xFF); - EEPROM.commit(); -} - -void didUseMessageCounter(uint32_t counter, uint8_t deviceId) { - // Store the next counter, so that resetting starts at 0 - setMessageCounter(counter+1, deviceId); -} - -void resetMessageCounters() { - for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { - setMessageCounter(0, i); - } - Serial.println("[WARN] Message counters reset"); -} +} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index f340f45..fc2ba50 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,7 +25,7 @@ ServoController servo(pwmTimer, servoFrequency, servoPin); AsyncWebServer local(localPort); -SesameController controller(&server, &servo, &local); +SesameController controller(&server, &servo, &local, remoteDeviceCount); // Forward declare monitoring functions void ensureWiFiConnection(uint32_t time); @@ -41,9 +41,7 @@ void setup() { servo.configure(lockOpeningDuration, servoPressedState, servoReleasedState); Serial.println("[INFO] Servo configured"); - prepareMessageCounterUsage(); - //resetMessageCounters(); - printMessageCounters(); + controller.configure(); } void loop() { diff --git a/src/storage.cpp b/src/storage.cpp new file mode 100644 index 0000000..6218453 --- /dev/null +++ b/src/storage.cpp @@ -0,0 +1,53 @@ +#include "storage.h" +#include "message.h" +#include + +void Storage::configure() { + EEPROM.begin(messageCounterSize * remoteDeviceCount); +} + +bool Storage::isDeviceIdValid(uint8_t deviceId) { + return deviceId < remoteDeviceCount; +} + +bool Storage::isMessageCounterValid(uint32_t counter, uint8_t deviceId) { + return counter >= getNextMessageCounter(deviceId); +} + +void Storage::didUseMessageCounter(uint32_t counter, uint8_t deviceId) { + // Store the next counter, so that resetting starts at 0 + setMessageCounter(counter+1, deviceId); +} + +void Storage::setMessageCounter(uint32_t counter, uint8_t deviceId) { + int offset = deviceId * messageCounterSize; + EEPROM.write(offset + 0, (counter >> 24) & 0xFF); + EEPROM.write(offset + 1, (counter >> 16) & 0xFF); + EEPROM.write(offset + 2, (counter >> 8) & 0xFF); + EEPROM.write(offset + 3, counter & 0xFF); + EEPROM.commit(); +} + +uint32_t Storage::getNextMessageCounter(uint8_t deviceId) { + int offset = deviceId * messageCounterSize; + uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24; + counter += (uint32_t) EEPROM.read(offset + 1) << 16; + counter += (uint32_t) EEPROM.read(offset + 2) << 8; + counter += (uint32_t) EEPROM.read(offset + 3); + return counter; +} + +void Storage::printMessageCounters() { + Serial.print("[INFO] Next message numbers:"); + for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { + Serial.printf(" %u", getNextMessageCounter(i)); + } + Serial.println(""); +} + +void Storage::resetMessageCounters() { + for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { + setMessageCounter(0, i); + } + Serial.println("[WARN] Message counters reset"); +}