Add device id to support multiple remotes

This commit is contained in:
Christoph Hagen 2023-08-07 15:17:04 +02:00
parent 8b196981ef
commit 360f3a1478
5 changed files with 58 additions and 38 deletions

View File

@ -12,6 +12,9 @@ constexpr uint32_t serialBaudRate = 115200;
/* Keys */ /* Keys */
// The number of remote devices
constexpr int remoteDeviceCount = 1;
// The size of the symmetric keys used for signing and verifying messages // The size of the symmetric keys used for signing and verifying messages
constexpr size_t keySize = 32; constexpr size_t keySize = 32;

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include "config.h"
/** /**
* @brief The size of the message counter in bytes (uint32_t) * @brief The size of the message counter in bytes (uint32_t)
@ -72,13 +73,15 @@ void prepareMessageCounterUsage();
* *
* @return The next counter to use by the remote * @return The next counter to use by the remote
*/ */
uint32_t getNextMessageCounter(); uint32_t getNextMessageCounter(uint8_t deviceId);
/** /**
* @brief Print info about the current message counter to the serial output * @brief Print info about the current message counter to the serial output
* *
*/ */
void printMessageCounter(); void printMessageCounters();
bool isDeviceIdValid(uint8_t deviceId);
/** /**
* @brief Check if a received counter is valid * @brief Check if a received counter is valid
@ -90,7 +93,7 @@ void printMessageCounter();
* @return true The counter is valid * @return true The counter is valid
* @return false The counter belongs to an old message * @return false The counter belongs to an old message
*/ */
bool isMessageCounterValid(uint32_t counter); bool isMessageCounterValid(uint32_t counter, uint8_t deviceId);
/** /**
* @brief Mark a counter of a message as used. * @brief Mark a counter of a message as used.
@ -101,7 +104,7 @@ bool isMessageCounterValid(uint32_t counter);
* *
* @param counter The counter used in the last message. * @param counter The counter used in the last message.
*/ */
void didUseMessageCounter(uint32_t counter); void didUseMessageCounter(uint32_t counter, uint8_t deviceId);
/** /**
* @brief Reset the message counter. * @brief Reset the message counter.
@ -111,4 +114,4 @@ void didUseMessageCounter(uint32_t counter);
* used for replay attacks. * used for replay attacks.
* *
*/ */
void resetMessageCounter(); void resetMessageCounters();

View File

@ -45,6 +45,11 @@ typedef struct {
*/ */
uint32_t id; uint32_t id;
/**
* @brief The id of the device sending the message
*/
uint8_t device;
} Message; } Message;
/** /**
@ -97,6 +102,7 @@ enum class SesameEvent {
MessageTimeMismatch = 5, MessageTimeMismatch = 5,
MessageCounterInvalid = 6, MessageCounterInvalid = 6,
MessageAccepted = 7, MessageAccepted = 7,
MessageDeviceInvalid = 8,
}; };
/** /**

View File

@ -4,11 +4,6 @@
#include <time.h> #include <time.h>
#include <EEPROM.h> #include <EEPROM.h>
/**
* @brief The size of the message counter in bytes (uint32_t)
*/
#define MESSAGE_COUNTER_SIZE sizeof(uint32_t)
/** /**
* @brief The allowed discrepancy between the time of a received message * @brief The allowed discrepancy between the time of a received message
* and the device time (in seconds) * and the device time (in seconds)
@ -63,40 +58,51 @@ bool isMessageTimeAcceptable(uint32_t t) {
} }
void prepareMessageCounterUsage() { void prepareMessageCounterUsage() {
EEPROM.begin(MESSAGE_COUNTER_SIZE); EEPROM.begin(MESSAGE_COUNTER_SIZE * remoteDeviceCount);
} }
uint32_t getNextMessageCounter() { uint32_t getNextMessageCounter(uint8_t deviceId) {
uint32_t counter = (uint32_t) EEPROM.read(0) << 24; int offset = deviceId * MESSAGE_COUNTER_SIZE;
counter += (uint32_t) EEPROM.read(1) << 16; uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24;
counter += (uint32_t) EEPROM.read(2) << 8; counter += (uint32_t) EEPROM.read(offset + 1) << 16;
counter += (uint32_t) EEPROM.read(3); counter += (uint32_t) EEPROM.read(offset + 2) << 8;
counter += (uint32_t) EEPROM.read(offset + 3);
return counter; return counter;
} }
void printMessageCounter() { void printMessageCounters() {
Serial.printf("[INFO] Next message number: %u\n", getNextMessageCounter()); Serial.print("[INFO] Next message numbers:");
for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
Serial.printf(" %u", getNextMessageCounter(i));
}
Serial.println("");
} }
bool isMessageCounterValid(uint32_t counter) { bool isDeviceIdValid(uint8_t deviceId) {
return counter >= getNextMessageCounter(); return deviceId < remoteDeviceCount;
} }
void setMessageCounter(uint32_t counter) { bool isMessageCounterValid(uint32_t counter, uint8_t deviceId) {
EEPROM.write(0, (counter >> 24) & 0xFF); return counter >= getNextMessageCounter(deviceId);
EEPROM.write(1, (counter >> 16) & 0xFF); }
EEPROM.write(2, (counter >> 8) & 0xFF);
EEPROM.write(3, counter & 0xFF);
void setMessageCounter(uint32_t counter, uint8_t deviceId) {
int offset = deviceId * MESSAGE_COUNTER_SIZE;
EEPROM.write(offset + 0, (counter >> 24) & 0xFF);
EEPROM.write(offset + 1, (counter >> 16) & 0xFF);
EEPROM.write(offset + 2, (counter >> 8) & 0xFF);
EEPROM.write(offset + 3, counter & 0xFF);
EEPROM.commit(); EEPROM.commit();
} }
void didUseMessageCounter(uint32_t counter) { void didUseMessageCounter(uint32_t counter, uint8_t deviceId) {
// Store the next counter, so that resetting starts at 0 // Store the next counter, so that resetting starts at 0
setMessageCounter(counter+1); setMessageCounter(counter+1, deviceId);
} }
void resetMessageCounter() { void resetMessageCounters() {
setMessageCounter(0); for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
Serial.println("[WARN] Message counter reset"); setMessageCounter(0, i);
}
Serial.println("[WARN] Message counters reset");
} }

View File

@ -52,7 +52,7 @@ void setup() {
prepareMessageCounterUsage(); prepareMessageCounterUsage();
//resetMessageCounter(); //resetMessageCounter();
printMessageCounter(); printMessageCounters();
server.onMessage(handleReceivedMessage); server.onMessage(handleReceivedMessage);
@ -72,9 +72,6 @@ void setup() {
}); });
} }
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void loop() { void loop() {
uint32_t time = millis(); uint32_t time = millis();
@ -85,6 +82,9 @@ void loop() {
ensureWebSocketConnection(time); ensureWebSocketConnection(time);
} }
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void ensureWiFiConnection(uint32_t time) { void ensureWiFiConnection(uint32_t time) {
// Reconnect to WiFi // Reconnect to WiFi
if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) { if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) {
@ -111,7 +111,10 @@ void ensureWebSocketConnection(uint32_t time) {
} }
SesameEvent processMessage(AuthenticatedMessage* message) { SesameEvent processMessage(AuthenticatedMessage* message) {
if (!isMessageCounterValid(message->message.id)) { if (!isDeviceIdValid(message->message.device)) {
return SesameEvent::MessageDeviceInvalid;
}
if (!isMessageCounterValid(message->message.id, message->message.device)) {
return SesameEvent::MessageCounterInvalid; return SesameEvent::MessageCounterInvalid;
} }
if (!isMessageTimeAcceptable(message->message.time)) { if (!isMessageTimeAcceptable(message->message.time)) {
@ -138,7 +141,7 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Only open when message is valid // Only open when message is valid
if (event == SesameEvent::MessageAccepted) { if (event == SesameEvent::MessageAccepted) {
didUseMessageCounter(message->message.id); didUseMessageCounter(message->message.id, message->message.device);
// Move servo // Move servo
servo.pressButton(); servo.pressButton();
Serial.printf("[Info] Accepted message %d\n", message->message.id); Serial.printf("[Info] Accepted message %d\n", message->message.id);
@ -146,14 +149,13 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Create response for all cases // Create response for all cases
response->message.time = getEpochTime(); response->message.time = getEpochTime();
response->message.id = getNextMessageCounter(); response->message.id = getNextMessageCounter(message->message.device);
if (!authenticateMessage(response, localKey, keySize)) { if (!authenticateMessage(response, localKey, keySize)) {
return SesameEvent::MessageAuthenticationFailed; return SesameEvent::MessageAuthenticationFailed;
} }
return event; return event;
} }
void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) { void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) {
uint8_t response = static_cast<uint8_t>(event); uint8_t response = static_cast<uint8_t>(event);
sendResponse(request, &response, 1); sendResponse(request, &response, 1);