Compare commits

...

5 Commits

Author SHA1 Message Date
Christoph Hagen
e631ea0a20 Typo 2023-08-07 15:51:59 +02:00
Christoph Hagen
360f3a1478 Add device id to support multiple remotes 2023-08-07 15:17:04 +02:00
Christoph Hagen
8b196981ef Replace base64 with hex 2023-04-11 17:53:47 +02:00
Christoph Hagen
03e8b90b1f Add local server option 2023-04-11 14:33:58 +02:00
Christoph Hagen
15f07464ca Make example config explicit 2022-07-15 15:03:32 +02:00
8 changed files with 186 additions and 65 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.pio
.vscode
.DS_Store
.DS_Store
include/config.h

View File

@@ -12,6 +12,9 @@ constexpr uint32_t serialBaudRate = 115200;
/* Keys */
// The number of remote devices
constexpr int remoteDeviceCount = 1;
// The size of the symmetric keys used for signing and verifying messages
constexpr size_t keySize = 32;
@@ -40,9 +43,20 @@ constexpr const char* wifiSSID = "MyWiFi";
// The WiFi password to connect to the above network
constexpr const char* wifiPassword = "00000000";
// The name of the device on the network
constexpr const char* networkName = "Sesame-Device";
// The interval to reconnect to WiFi if the connection is broken
constexpr uint32_t wifiReconnectInterval = 10000;
/* Local server */
// The port for the local server to directly receive messages over WiFi
constexpr uint16_t localPort = 80;
// The url parameter to send the message to the local server
constexpr char messageUrlParameter[] = "m";
/* Server */

View File

@@ -1,6 +1,7 @@
#pragma once
#include <stdint.h>
#include "config.h"
/**
* @brief The size of the message counter in bytes (uint32_t)
@@ -72,13 +73,15 @@ void prepareMessageCounterUsage();
*
* @return The next counter to use by the remote
*/
uint32_t getNextMessageCounter();
uint32_t getNextMessageCounter(uint8_t deviceId);
/**
* @brief Print info about the current message counter to the serial output
*
*/
void printMessageCounter();
void printMessageCounters();
bool isDeviceIdValid(uint8_t deviceId);
/**
* @brief Check if a received counter is valid
@@ -90,7 +93,7 @@ void printMessageCounter();
* @return true The counter is valid
* @return false The counter belongs to an old message
*/
bool isMessageCounterValid(uint32_t counter);
bool isMessageCounterValid(uint32_t counter, uint8_t deviceId);
/**
* @brief Mark a counter of a message as used.
@@ -101,7 +104,7 @@ bool isMessageCounterValid(uint32_t counter);
*
* @param counter The counter used in the last message.
*/
void didUseMessageCounter(uint32_t counter);
void didUseMessageCounter(uint32_t counter, uint8_t deviceId);
/**
* @brief Reset the message counter.
@@ -111,4 +114,4 @@ void didUseMessageCounter(uint32_t counter);
* used for replay attacks.
*
*/
void resetMessageCounter();
void resetMessageCounters();

View File

@@ -45,6 +45,11 @@ typedef struct {
*/
uint32_t id;
/**
* @brief The id of the device sending the message
*/
uint8_t device;
} Message;
/**
@@ -84,4 +89,27 @@ typedef struct {
#define MESSAGE_CONTENT_SIZE sizeof(Message)
#define AUTHENTICATED_MESSAGE_SIZE sizeof(AuthenticatedMessage)
#define AUTHENTICATED_MESSAGE_SIZE sizeof(AuthenticatedMessage)
/**
* An event signaled from the device
*/
enum class SesameEvent {
TextReceived = 1,
UnexpectedSocketEvent = 2,
InvalidMessageData = 3,
MessageAuthenticationFailed = 4,
MessageTimeMismatch = 5,
MessageCounterInvalid = 6,
MessageAccepted = 7,
MessageDeviceInvalid = 8,
};
/**
* @brief A callback for messages received over the socket
*
* The first parameter is the received message.
* The second parameter is the response to the remote.
* The return value is the type of event to respond with.
*/
typedef SesameEvent (*MessageCallback)(AuthenticatedMessage*, AuthenticatedMessage*);

View File

@@ -6,28 +6,6 @@
#include <WiFiClientSecure.h>
#include <WebSocketsClient.h>
/**
* An event signaled from the device
*/
enum class SesameEvent {
TextReceived = 1,
UnexpectedSocketEvent = 2,
InvalidMessageData = 3,
MessageAuthenticationFailed = 4,
MessageTimeMismatch = 5,
MessageCounterInvalid = 6,
MessageAccepted = 7,
};
/**
* @brief A callback for messages received over the socket
*
* The first parameter is the received message.
* The second parameter is the response to the remote.
* The return value is the type of event to respond with.
*/
typedef SesameEvent (*MessageCallback)(AuthenticatedMessage*, AuthenticatedMessage*);
class ServerConnection {
public:

View File

@@ -14,5 +14,6 @@ board = az-delivery-devkit-v4
framework = arduino
lib_deps =
links2004/WebSockets@^2.3.7
madhephaestus/ESP32Servo@^0.11.0
monitor_speed = 115200
madhephaestus/ESP32Servo@^0.13.0
ottowinter/ESPAsyncWebServer-esphome@^3.0.0
monitor_speed = 115200

View File

@@ -4,11 +4,6 @@
#include <time.h>
#include <EEPROM.h>
/**
* @brief The size of the message counter in bytes (uint32_t)
*/
#define MESSAGE_COUNTER_SIZE sizeof(uint32_t)
/**
* @brief The allowed discrepancy between the time of a received message
* and the device time (in seconds)
@@ -63,40 +58,51 @@ bool isMessageTimeAcceptable(uint32_t t) {
}
void prepareMessageCounterUsage() {
EEPROM.begin(MESSAGE_COUNTER_SIZE);
EEPROM.begin(MESSAGE_COUNTER_SIZE * remoteDeviceCount);
}
uint32_t getNextMessageCounter() {
uint32_t counter = (uint32_t) EEPROM.read(0) << 24;
counter += (uint32_t) EEPROM.read(1) << 16;
counter += (uint32_t) EEPROM.read(2) << 8;
counter += (uint32_t) EEPROM.read(3);
uint32_t getNextMessageCounter(uint8_t deviceId) {
int offset = deviceId * MESSAGE_COUNTER_SIZE;
uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24;
counter += (uint32_t) EEPROM.read(offset + 1) << 16;
counter += (uint32_t) EEPROM.read(offset + 2) << 8;
counter += (uint32_t) EEPROM.read(offset + 3);
return counter;
}
void printMessageCounter() {
Serial.printf("[INFO] Next message number: %u\n", getNextMessageCounter());
void printMessageCounters() {
Serial.print("[INFO] Next message numbers:");
for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
Serial.printf(" %u", getNextMessageCounter(i));
}
Serial.println("");
}
bool isMessageCounterValid(uint32_t counter) {
return counter >= getNextMessageCounter();
bool isDeviceIdValid(uint8_t deviceId) {
return deviceId < remoteDeviceCount;
}
void setMessageCounter(uint32_t counter) {
EEPROM.write(0, (counter >> 24) & 0xFF);
EEPROM.write(1, (counter >> 16) & 0xFF);
EEPROM.write(2, (counter >> 8) & 0xFF);
EEPROM.write(3, counter & 0xFF);
bool isMessageCounterValid(uint32_t counter, uint8_t deviceId) {
return counter >= getNextMessageCounter(deviceId);
}
void setMessageCounter(uint32_t counter, uint8_t deviceId) {
int offset = deviceId * MESSAGE_COUNTER_SIZE;
EEPROM.write(offset + 0, (counter >> 24) & 0xFF);
EEPROM.write(offset + 1, (counter >> 16) & 0xFF);
EEPROM.write(offset + 2, (counter >> 8) & 0xFF);
EEPROM.write(offset + 3, counter & 0xFF);
EEPROM.commit();
}
void didUseMessageCounter(uint32_t counter) {
void didUseMessageCounter(uint32_t counter, uint8_t deviceId) {
// Store the next counter, so that resetting starts at 0
setMessageCounter(counter+1);
setMessageCounter(counter+1, deviceId);
}
void resetMessageCounter() {
setMessageCounter(0);
Serial.println("[WARN] Message counter reset");
void resetMessageCounters() {
for (uint8_t i = 0; i < remoteDeviceCount; i += 1) {
setMessageCounter(0, i);
}
Serial.println("[WARN] Message counters reset");
}

View File

@@ -7,6 +7,7 @@
*/
#include <Arduino.h>
#include <WiFi.h>
#include <ESPAsyncWebServer.h>
#include "crypto.h"
#include "fresh.h"
@@ -21,10 +22,24 @@ ServerConnection server(serverUrl, serverPort, serverPath);
ServoController servo(pwmTimer, servoFrequency, servoPin);
AsyncWebServer local(localPort);
// The buffer to hold a received message while it is read
uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE];
/* Event callbacks */
SesameEvent handleReceivedMessage(AuthenticatedMessage* payload, AuthenticatedMessage* response);
// Forward declare monitoring functions
void ensureWiFiConnection(uint32_t time);
void ensureWebSocketConnection(uint32_t time);
void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event);
void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message);
void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size);
void hexToBin(const char * str, uint8_t * bytes, size_t blen);
/* Logic */
void setup() {
@@ -36,39 +51,69 @@ void setup() {
Serial.println("[INFO] Servo configured");
prepareMessageCounterUsage();
//resetMessageCounter();
printMessageCounter();
//resetMessageCounters();
printMessageCounters();
server.onMessage(handleReceivedMessage);
}
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
local.on("/message", HTTP_POST, [] (AsyncWebServerRequest *request) {
if (!request->hasParam(messageUrlParameter)) {
Serial.println("Missing url parameter");
sendFailureResponse(request, SesameEvent::InvalidMessageData);
return;
}
String encoded = request->getParam(messageUrlParameter)->value();
hexToBin(encoded.c_str(), receivedMessageBuffer, AUTHENTICATED_MESSAGE_SIZE);
// Process received message
AuthenticatedMessage* message = (AuthenticatedMessage*) receivedMessageBuffer;
AuthenticatedMessage responseMessage;
SesameEvent event = handleReceivedMessage(message, &responseMessage);
sendMessageResponse(request, event, &responseMessage);
});
}
void loop() {
uint32_t time = millis();
server.loop();
servo.loop();
ensureWiFiConnection(time);
ensureWebSocketConnection(time);
}
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void ensureWiFiConnection(uint32_t time) {
// Reconnect to WiFi
if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) {
Serial.println("[INFO] Reconnecting WiFi...");
WiFi.setHostname(networkName);
WiFi.begin(wifiSSID, wifiPassword);
isReconnecting = true;
nextWifiReconnect = time + wifiReconnectInterval;
}
}
void ensureWebSocketConnection(uint32_t time) {
if (isReconnecting && WiFi.status() == WL_CONNECTED) {
isReconnecting = false;
Serial.println("IP address: ");
Serial.println(WiFi.localIP());
Serial.println("[INFO] WiFi connected, opening socket");
server.connectSSL(serverAccessKey);
configureNTP(timeOffsetToGMT, timeOffsetDaylightSavings, ntpServerUrl);
printLocalTime();
local.begin();
}
}
SesameEvent processMessage(AuthenticatedMessage* message) {
if (!isMessageCounterValid(message->message.id)) {
if (!isDeviceIdValid(message->message.device)) {
return SesameEvent::MessageDeviceInvalid;
}
if (!isMessageCounterValid(message->message.id, message->message.device)) {
return SesameEvent::MessageCounterInvalid;
}
if (!isMessageTimeAcceptable(message->message.time)) {
@@ -95,7 +140,7 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Only open when message is valid
if (event == SesameEvent::MessageAccepted) {
didUseMessageCounter(message->message.id);
didUseMessageCounter(message->message.id, message->message.device);
// Move servo
servo.pressButton();
Serial.printf("[Info] Accepted message %d\n", message->message.id);
@@ -103,9 +148,54 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe
// Create response for all cases
response->message.time = getEpochTime();
response->message.id = getNextMessageCounter();
response->message.id = getNextMessageCounter(message->message.device);
if (!authenticateMessage(response, localKey, keySize)) {
return SesameEvent::MessageAuthenticationFailed;
}
return event;
}
void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) {
uint8_t response = static_cast<uint8_t>(event);
sendResponse(request, &response, 1);
}
void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message) {
uint8_t response[AUTHENTICATED_MESSAGE_SIZE+1];
response[0] = static_cast<uint8_t>(event);
memcpy(response+1, (uint8_t*) message, AUTHENTICATED_MESSAGE_SIZE);
sendResponse(request, response, AUTHENTICATED_MESSAGE_SIZE+1);
}
void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size) {
request->send_P(200, "application/octet-stream", buffer, size);
Serial.printf("[INFO] Local response %d\n", buffer[0]);
}
// Based on https://stackoverflow.com/a/23898449/266720
void hexToBin(const char * str, uint8_t * bytes, size_t blen) {
uint8_t idx0, idx1;
// mapping of ASCII characters to hex values
const uint8_t hashmap[] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // 01234567
0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 89:;<=>?
0x00, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, // @ABCDEFG
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // HIJKLMNO
};
memset(bytes, 0, blen);
size_t len = strlen(str);
if (len % 2) {
// Require two chars per byte
return;
}
size_t end = min(blen*2, len);
for (size_t pos = 0; pos < end; pos += 2) {
idx0 = ((uint8_t)str[pos+0] & 0x1F) ^ 0x10;
idx1 = ((uint8_t)str[pos+1] & 0x1F) ^ 0x10;
bytes[pos/2] = (uint8_t)(hashmap[idx0] << 4) | hashmap[idx1];
};
}