diff --git a/include/example_config.h b/include/example_config.h index 11ae6c4..0507c2b 100644 --- a/include/example_config.h +++ b/include/example_config.h @@ -12,6 +12,9 @@ constexpr uint32_t serialBaudRate = 115200; /* Keys */ +// The number of remote devices +constexpr int remoteDeviceCount = 1; + // The size of the symmetric keys used for signing and verifying messages constexpr size_t keySize = 32; diff --git a/include/fresh.h b/include/fresh.h index de8ab60..cd017cc 100644 --- a/include/fresh.h +++ b/include/fresh.h @@ -1,6 +1,7 @@ #pragma once #include +#include "config.h" /** * @brief The size of the message counter in bytes (uint32_t) @@ -72,13 +73,15 @@ void prepareMessageCounterUsage(); * * @return The next counter to use by the remote */ -uint32_t getNextMessageCounter(); +uint32_t getNextMessageCounter(uint8_t deviceId); /** * @brief Print info about the current message counter to the serial output * */ -void printMessageCounter(); +void printMessageCounters(); + +bool isDeviceIdValid(uint8_t deviceId); /** * @brief Check if a received counter is valid @@ -90,7 +93,7 @@ void printMessageCounter(); * @return true The counter is valid * @return false The counter belongs to an old message */ -bool isMessageCounterValid(uint32_t counter); +bool isMessageCounterValid(uint32_t counter, uint8_t deviceId); /** * @brief Mark a counter of a message as used. @@ -101,7 +104,7 @@ bool isMessageCounterValid(uint32_t counter); * * @param counter The counter used in the last message. */ -void didUseMessageCounter(uint32_t counter); +void didUseMessageCounter(uint32_t counter, uint8_t deviceId); /** * @brief Reset the message counter. @@ -111,4 +114,4 @@ void didUseMessageCounter(uint32_t counter); * used for replay attacks. * */ -void resetMessageCounter(); +void resetMessageCounters(); diff --git a/include/message.h b/include/message.h index 8daf7f5..beb6ae8 100644 --- a/include/message.h +++ b/include/message.h @@ -45,6 +45,11 @@ typedef struct { */ uint32_t id; + /** + * @brief The id of the device sending the message + */ + uint8_t device; + } Message; /** @@ -97,6 +102,7 @@ enum class SesameEvent { MessageTimeMismatch = 5, MessageCounterInvalid = 6, MessageAccepted = 7, + MessageDeviceInvalid = 8, }; /** diff --git a/src/fresh.cpp b/src/fresh.cpp index 8ce1e00..a891a9b 100644 --- a/src/fresh.cpp +++ b/src/fresh.cpp @@ -4,11 +4,6 @@ #include #include -/** - * @brief The size of the message counter in bytes (uint32_t) - */ -#define MESSAGE_COUNTER_SIZE sizeof(uint32_t) - /** * @brief The allowed discrepancy between the time of a received message * and the device time (in seconds) @@ -63,40 +58,51 @@ bool isMessageTimeAcceptable(uint32_t t) { } void prepareMessageCounterUsage() { - EEPROM.begin(MESSAGE_COUNTER_SIZE); + EEPROM.begin(MESSAGE_COUNTER_SIZE * remoteDeviceCount); } -uint32_t getNextMessageCounter() { - uint32_t counter = (uint32_t) EEPROM.read(0) << 24; - counter += (uint32_t) EEPROM.read(1) << 16; - counter += (uint32_t) EEPROM.read(2) << 8; - counter += (uint32_t) EEPROM.read(3); +uint32_t getNextMessageCounter(uint8_t deviceId) { + int offset = deviceId * MESSAGE_COUNTER_SIZE; + uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24; + counter += (uint32_t) EEPROM.read(offset + 1) << 16; + counter += (uint32_t) EEPROM.read(offset + 2) << 8; + counter += (uint32_t) EEPROM.read(offset + 3); return counter; } -void printMessageCounter() { - Serial.printf("[INFO] Next message number: %u\n", getNextMessageCounter()); +void printMessageCounters() { + Serial.print("[INFO] Next message numbers:"); + for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { + Serial.printf(" %u", getNextMessageCounter(i)); + } + Serial.println(""); } -bool isMessageCounterValid(uint32_t counter) { - return counter >= getNextMessageCounter(); +bool isDeviceIdValid(uint8_t deviceId) { + return deviceId < remoteDeviceCount; } -void setMessageCounter(uint32_t counter) { - EEPROM.write(0, (counter >> 24) & 0xFF); - EEPROM.write(1, (counter >> 16) & 0xFF); - EEPROM.write(2, (counter >> 8) & 0xFF); - EEPROM.write(3, counter & 0xFF); +bool isMessageCounterValid(uint32_t counter, uint8_t deviceId) { + return counter >= getNextMessageCounter(deviceId); +} +void setMessageCounter(uint32_t counter, uint8_t deviceId) { + int offset = deviceId * MESSAGE_COUNTER_SIZE; + EEPROM.write(offset + 0, (counter >> 24) & 0xFF); + EEPROM.write(offset + 1, (counter >> 16) & 0xFF); + EEPROM.write(offset + 2, (counter >> 8) & 0xFF); + EEPROM.write(offset + 3, counter & 0xFF); EEPROM.commit(); } -void didUseMessageCounter(uint32_t counter) { +void didUseMessageCounter(uint32_t counter, uint8_t deviceId) { // Store the next counter, so that resetting starts at 0 - setMessageCounter(counter+1); + setMessageCounter(counter+1, deviceId); } -void resetMessageCounter() { - setMessageCounter(0); - Serial.println("[WARN] Message counter reset"); +void resetMessageCounters() { + for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { + setMessageCounter(0, i); + } + Serial.println("[WARN] Message counters reset"); } diff --git a/src/main.cpp b/src/main.cpp index a142127..e7405eb 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -52,7 +52,7 @@ void setup() { prepareMessageCounterUsage(); //resetMessageCounter(); - printMessageCounter(); + printMessageCounters(); server.onMessage(handleReceivedMessage); @@ -72,9 +72,6 @@ void setup() { }); } -uint32_t nextWifiReconnect = 0; -bool isReconnecting = false; - void loop() { uint32_t time = millis(); @@ -85,6 +82,9 @@ void loop() { ensureWebSocketConnection(time); } +uint32_t nextWifiReconnect = 0; +bool isReconnecting = false; + void ensureWiFiConnection(uint32_t time) { // Reconnect to WiFi if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) { @@ -111,7 +111,10 @@ void ensureWebSocketConnection(uint32_t time) { } SesameEvent processMessage(AuthenticatedMessage* message) { - if (!isMessageCounterValid(message->message.id)) { + if (!isDeviceIdValid(message->message.device)) { + return SesameEvent::MessageDeviceInvalid; + } + if (!isMessageCounterValid(message->message.id, message->message.device)) { return SesameEvent::MessageCounterInvalid; } if (!isMessageTimeAcceptable(message->message.time)) { @@ -138,7 +141,7 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe // Only open when message is valid if (event == SesameEvent::MessageAccepted) { - didUseMessageCounter(message->message.id); + didUseMessageCounter(message->message.id, message->message.device); // Move servo servo.pressButton(); Serial.printf("[Info] Accepted message %d\n", message->message.id); @@ -146,14 +149,13 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe // Create response for all cases response->message.time = getEpochTime(); - response->message.id = getNextMessageCounter(); + response->message.id = getNextMessageCounter(message->message.device); if (!authenticateMessage(response, localKey, keySize)) { return SesameEvent::MessageAuthenticationFailed; } return event; } - void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) { uint8_t response = static_cast(event); sendResponse(request, &response, 1);