Add device id to support multiple remotes

This commit is contained in:
Christoph Hagen 2023-08-07 15:17:04 +02:00
parent 8b196981ef
commit 360f3a1478
5 changed files with 58 additions and 38 deletions

View File

@ -12,6 +12,9 @@ constexpr uint32_t serialBaudRate = 115200;
/* Keys */
// The number of remote devices
constexpr int remoteDeviceCount = 1;
// The size of the symmetric keys used for signing and verifying messages
constexpr size_t keySize = 32;

View File

@ -1,6 +1,7 @@
#pragma once
#include <stdint.h>
#include "config.h"
/**
* @brief The size of the message counter in bytes (uint32_t)
@ -72,13 +73,15 @@ void prepareMessageCounterUsage();
*
* @return The next counter to use by the remote
*/
uint32_t getNextMessageCounter();
uint32_t getNextMessageCounter(uint8_t deviceId);
/**
* @brief Print info about the current message counter to the serial output
*
*/
void printMessageCounter();
void printMessageCounters();
bool isDeviceIdValid(uint8_t deviceId);
/**
* @brief Check if a received counter is valid
@ -90,7 +93,7 @@ void printMessageCounter();
* @return true The counter is valid
* @return false The counter belongs to an old message
*/
bool isMessageCounterValid(uint32_t counter);
bool isMessageCounterValid(uint32_t counter, uint8_t deviceId);
/**
* @brief Mark a counter of a message as used.
@ -101,7 +104,7 @@ bool isMessageCounterValid(uint32_t counter);
*
* @param counter The counter used in the last message.
*/
void didUseMessageCounter(uint32_t counter);
void didUseMessageCounter(uint32_t counter, uint8_t deviceId);
/**
* @brief Reset the message counter.
@ -111,4 +114,4 @@ void didUseMessageCounter(uint32_t counter);
* used for replay attacks.
*
*/
void resetMessageCounter();
void resetMessageCounters();

View File

@ -45,6 +45,11 @@ typedef struct {
*/
uint32_t id;
/**
* @brief The id of the device sending the message
*/
uint8_t device;
} Message;
/**
@ -97,6 +102,7 @@ enum class SesameEvent {
MessageTimeMismatch = 5,
MessageCounterInvalid = 6,
MessageAccepted = 7,
MessageDeviceInvalid = 8,
};
/**

View File

@ -4,11 +4,6 @@
#include <time.h>
#include <EEPROM.h>
/**
* @brief The size of the message counter in bytes (uint32_t)
*/
#define MESSAGE_COUNTER_SIZE sizeof(uint32_t)
/**
* @brief The allowed discrepancy between the time of a received message
* and the device time (in seconds)
@ -63,40 +58,51 @@ bool isMessageTimeAcceptable(uint32_t t) {
}
void prepareMessageCounterUsage() {
EEPROM.begin(MESSAGE_COUNTER_SIZE);
EEPROM.begin(MESSAGE_COUNTER_SIZE * remoteDeviceCount);
}
uint32_t getNextMessageCounter() {
uint32_t counter = (uint32_t) EEPROM.read(0) << 24;
counter += (uint32_t) EEPROM.read(1) << 16;
counter += (uint32_t) EEPROM.read(2) << 8;
counter += (uint32_t) EEPROM.read(3);
uint32_t getNextMessageCounter(uint8_t deviceId) {
int offset = deviceId * MESSAGE_COUNTER_SIZE;
uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24;
counter += (uint32_t) EEPROM.read(offset + 1) << 16;
counter += (uint32_t) EEPROM.read(offset + 2) << 8;
counter += (uint32_t) EEPROM.read(offset + 3);
return counter;
}
void printMessageCounter() {
Serial.printf("[INFO] Next message number: %u\n", getNextMessageCounter());
void printMessageCounters() {
Serial.print("[INFO] Next message numbers:");
for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
Serial.printf(" %u", getNextMessageCounter(i));
}
Serial.println("");
}
bool isMessageCounterValid(uint32_t counter) {
return counter >= getNextMessageCounter();
bool isDeviceIdValid(uint8_t deviceId) {
return deviceId < remoteDeviceCount;
}
void setMessageCounter(uint32_t counter) {
EEPROM.write(0, (counter >> 24) & 0xFF);
EEPROM.write(1, (counter >> 16) & 0xFF);
EEPROM.write(2, (counter >> 8) & 0xFF);
EEPROM.write(3, counter & 0xFF);
bool isMessageCounterValid(uint32_t counter, uint8_t deviceId) {
return counter >= getNextMessageCounter(deviceId);
}
void setMessageCounter(uint32_t counter, uint8_t deviceId) {
int offset = deviceId * MESSAGE_COUNTER_SIZE;
EEPROM.write(offset + 0, (counter >> 24) & 0xFF);
EEPROM.write(offset + 1, (counter >> 16) & 0xFF);
EEPROM.write(offset + 2, (counter >> 8) & 0xFF);
EEPROM.write(offset + 3, counter & 0xFF);
EEPROM.commit();
}
void didUseMessageCounter(uint32_t counter) {
void didUseMessageCounter(uint32_t counter, uint8_t deviceId) {
// Store the next counter, so that resetting starts at 0
setMessageCounter(counter+1);
setMessageCounter(counter+1, deviceId);
}
void resetMessageCounter() {
setMessageCounter(0);
Serial.println("[WARN] Message counter reset");
void resetMessageCounters() {
for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
setMessageCounter(0, i);
}
Serial.println("[WARN] Message counters reset");
}

View File

@ -52,7 +52,7 @@ void setup() {
prepareMessageCounterUsage();
//resetMessageCounter();
printMessageCounter();
printMessageCounters();
server.onMessage(handleReceivedMessage);
@ -72,9 +72,6 @@ void setup() {
});
}
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void loop() {
uint32_t time = millis();
@ -85,6 +82,9 @@ void loop() {
ensureWebSocketConnection(time);
}
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void ensureWiFiConnection(uint32_t time) {
// Reconnect to WiFi
if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) {
@ -111,7 +111,10 @@ void ensureWebSocketConnection(uint32_t time) {
}
SesameEvent processMessage(AuthenticatedMessage* message) {
if (!isMessageCounterValid(message->message.id)) {
if (!isDeviceIdValid(message->message.device)) {
return SesameEvent::MessageDeviceInvalid;
}
if (!isMessageCounterValid(message->message.id, message->message.device)) {
return SesameEvent::MessageCounterInvalid;
}
if (!isMessageTimeAcceptable(message->message.time)) {
@ -138,7 +141,7 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Only open when message is valid
if (event == SesameEvent::MessageAccepted) {
didUseMessageCounter(message->message.id);
didUseMessageCounter(message->message.id, message->message.device);
// Move servo
servo.pressButton();
Serial.printf("[Info] Accepted message %d\n", message->message.id);
@ -146,14 +149,13 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Create response for all cases
response->message.time = getEpochTime();
response->message.id = getNextMessageCounter();
response->message.id = getNextMessageCounter(message->message.device);
if (!authenticateMessage(response, localKey, keySize)) {
return SesameEvent::MessageAuthenticationFailed;
}
return event;
}
void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) {
uint8_t response = static_cast<uint8_t>(event);
sendResponse(request, &response, 1);