Improve configuration, refactoring

This commit is contained in:
Christoph Hagen 2023-08-09 15:02:24 +02:00
parent b99245085e
commit e84e388521
10 changed files with 235 additions and 165 deletions

View File

@ -7,21 +7,50 @@
#include "fresh.h" #include "fresh.h"
#include <ESPAsyncWebServer.h> #include <ESPAsyncWebServer.h>
struct WifiConfiguration {
// The WiFi network to connect to
const char* ssid;
// The WiFi password to connect to the above network
const char* password;
// The name of the device on the network
const char* networkName;
// The interval to reconnect to WiFi if the connection is broken
uint32_t reconnectInterval;
};
struct KeyConfiguration {
const uint8_t* remoteKey;
const uint8_t* localKey;
};
class SesameController: public ServerConnectionCallbacks { class SesameController: public ServerConnectionCallbacks {
public: public:
SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local, TimeCheck* timeCheck, uint8_t remoteDeviceCount); SesameController(uint16_t localWebServerPort, uint8_t remoteDeviceCount);
void configure(); void configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, TimeConfiguration timeConfig, WifiConfiguration wifiConfig, KeyConfiguration keyConfig);
void loop(uint32_t millis);
private: private:
ServerConnection* server; ServerConnection server;
ServoController* servo; ServoController servo;
AsyncWebServer* local; AsyncWebServer localWebServer;
TimeCheck* timeCheck; TimeCheck timeCheck;
Storage storage; Storage storage;
WifiConfiguration wifiConfig;
KeyConfiguration keyConfig;
bool isReconnecting = false;
// The buffer to hold a received message while it is read // The buffer to hold a received message while it is read
uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE];
@ -31,6 +60,9 @@ private:
AuthenticatedMessage* responseMessage; AuthenticatedMessage* responseMessage;
uint16_t responseSize = 0; uint16_t responseSize = 0;
void ensureWiFiConnection(uint32_t time);
void ensureWebSocketConnection();
void handleLocalMessage(AsyncWebServerRequest *request); void handleLocalMessage(AsyncWebServerRequest *request);
// Based on https://stackoverflow.com/a/23898449/266720 // Based on https://stackoverflow.com/a/23898449/266720
bool convertHexMessageToBinary(const char* str); bool convertHexMessageToBinary(const char* str);

View File

@ -52,11 +52,7 @@ constexpr uint32_t wifiReconnectInterval = 10000;
/* Local server */ /* Local server */
// The port for the local server to directly receive messages over WiFi // The port for the local server to directly receive messages over WiFi
constexpr uint16_t localPort = 80; constexpr uint16_t localWebServerPort = 80;
// The url parameter to send the message to the local server
constexpr char messageUrlParameter[] = "m";
/* Server */ /* Server */

View File

@ -1,7 +1,34 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include "config.h"
struct TimeConfiguration {
/**
* @brief The timezone offset in seconds
*/
int32_t offsetToGMT;
/**
* @brief The daylight savings offset in seconds
*/
int32_t offsetDaylightSavings;
/**
* @brief The url of the NTP server
*/
const char* ntpServerUrl;
/**
* @brief The allowed discrepancy between the time of a received message
* and the device time (in seconds)
*
* A stricter (lower) value better prevents against replay attacks,
* but may lead to issues when dealing with slow networks and other
* routing delays.
*/
uint32_t allowedTimeOffset;
};
class TimeCheck { class TimeCheck {
@ -9,26 +36,18 @@ public:
/** /**
* @brief Create a time checker instance * @brief Create a time checker instance
*
* Specify the allowed discrepancy between the time of a received message
* and the device time (in seconds).
*
* A stricter (lower) value better prevents against replay attacks,
* but may lead to issues when dealing with slow networks and other
* routing delays.
*
* @param offset The allowed time discrepancy in both directions (seconds)
*/ */
TimeCheck(uint32_t allowedTimeOffset = 60); TimeCheck();
/** /**
* @brief Configure an NTP server to get the current time * @brief Set the configuration
*
* @param offsetToGMT The timezone offset in seconds
* @param offsetDaylightSavings The daylight savings offset in seconds
* @param serverUrl The url of the NTP server
*/ */
void configureNTP(int32_t offsetToGMT, int32_t offsetDaylightSavings, const char* serverUrl); void configure(TimeConfiguration configuration);
/**
* @brief Configure the NTP server to get the current time
*/
void startNTP();
/** /**
* @brief Print the current time to the serial output * @brief Print the current time to the serial output
@ -42,20 +61,6 @@ public:
*/ */
uint32_t getEpochTime(); uint32_t getEpochTime();
/**
* @brief The allowed time discrepancy (in seconds)
*
* Specifies the allowed discrepancy between the time of a received message
* and the device time (in seconds).
*
* A stricter (lower) value better prevents against replay attacks,
* but may lead to issues when dealing with slow networks and other
* routing delays.
*
* @param offset The offset in both directions (seconds)
*/
void setMessageTimeAllowedOffset(uint32_t offset);
/** /**
* @brief Check wether the time of a message is within the allowed bounds regarding freshness. * @brief Check wether the time of a message is within the allowed bounds regarding freshness.
* *
@ -70,14 +75,6 @@ public:
bool isMessageTimeAcceptable(uint32_t messageTime); bool isMessageTimeAcceptable(uint32_t messageTime);
private: private:
/** TimeConfiguration config;
* @brief The allowed discrepancy between the time of a received message
* and the device time (in seconds)
*
* A stricter (lower) value better prevents against replay attacks,
* but may lead to issues when dealing with slow networks and other
* routing delays.
*/
uint32_t allowedOffset;
}; };

View File

@ -6,6 +6,32 @@
#include <WiFiClientSecure.h> #include <WiFiClientSecure.h>
#include <WebSocketsClient.h> #include <WebSocketsClient.h>
struct ServerConfiguration {
/**
* @brief The url of the remote server to connect to
*/
const char* url;
/**
* @brief The server port
*/
int port;
/**
* @brief The path on the server
*/
const char* path;
/**
* @brief The authentication key for the server
*/
const char* key;
uint32_t reconnectTime;
};
class ServerConnectionCallbacks { class ServerConnectionCallbacks {
public: public:
@ -18,18 +44,18 @@ class ServerConnection {
public: public:
ServerConnection(const char* url, int port, const char* path); ServerConnection();
void connect(const char* key, uint32_t reconnectTime = 5000);
void loop();
/** /**
* @brief Set the handler * @brief Set the configuration and the callback handler
* *
* @param callback The handler to handle messages and errors * @param callback The handler to handle messages and errors
*/ */
void setCallbackHandler(ServerConnectionCallbacks* callbacks); void configure(ServerConfiguration configuration, ServerConnectionCallbacks* callbacks);
void connect();
void loop();
/** /**
* @brief Send a response message over the socket * @brief Send a response message over the socket
@ -45,13 +71,7 @@ public:
private: private:
const char* url; ServerConfiguration configuration;
int port;
const char* path;
const char* key = NULL;
// Indicator that the socket is connected. // Indicator that the socket is connected.
bool socketIsConnected = false; bool socketIsConnected = false;

View File

@ -3,7 +3,7 @@
#include <stdint.h> #include <stdint.h>
#include <ESP32Servo.h> // To control the servo #include <ESP32Servo.h> // To control the servo
typedef struct ServoConfiguration { struct ServoConfiguration {
/** /**
* @brief The timer to use for the servo control * @brief The timer to use for the servo control
* number 0-3 indicating which timer to allocate in this library * number 0-3 indicating which timer to allocate in this library
@ -34,6 +34,7 @@ typedef struct ServoConfiguration {
* @brief The servo value (in µs) that specifies the 'released' state * @brief The servo value (in µs) that specifies the 'released' state
*/ */
int releasedValue; int releasedValue;
}; };
/** /**
@ -70,7 +71,7 @@ public:
* There is no required interval to call this function, but the accuracy of * There is no required interval to call this function, but the accuracy of
* the opening interval is dependent on the calling frequency. * the opening interval is dependent on the calling frequency.
*/ */
void loop(); void loop(uint32_t millis);
/** /**
* Push the door opener button down by moving the servo arm. * Push the door opener button down by moving the servo arm.

View File

@ -1,33 +1,58 @@
#include "controller.h" #include "controller.h"
#include "config.h"
#include "fresh.h"
#include "crypto.h" #include "crypto.h"
#include "config.h"
SesameController::SesameController(ServerConnection* server, ServoController* servo, AsyncWebServer* local, TimeCheck* timeCheck, uint8_t remoteDeviceCount) : #include <WiFi.h>
server(server), servo(servo), local(local), timeCheck(timeCheck), storage(remoteDeviceCount) {
// The url parameter to send the message to the local server
constexpr char messageUrlParameter[] = "m";
SesameController::SesameController(uint16_t localWebServerPort, uint8_t remoteDeviceCount) :
storage(remoteDeviceCount), localWebServer(localWebServerPort) {
// Set up response buffer // Set up response buffer
responseStatus = (SesameEvent*) responseBuffer; responseStatus = (SesameEvent*) responseBuffer;
responseMessage = (AuthenticatedMessage*) (responseBuffer + 1); responseMessage = (AuthenticatedMessage*) (responseBuffer + 1);
} }
void SesameController::configure() { void SesameController::configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, TimeConfiguration timeConfig, WifiConfiguration wifiConfig, KeyConfiguration keyConfig) {
this->wifiConfig = wifiConfig;
this->keyConfig = keyConfig;
// Prepare EEPROM for reading and writing // Prepare EEPROM for reading and writing
storage.configure(); storage.configure();
Serial.println("[INFO] Storage configured");
// Direct messages and errors over the websocket to the controller servo.configure(servoConfig);
server->setCallbackHandler(this); Serial.println("[INFO] Servo configured");
// Direct messages and errors over the websocket to the controller
server.configure(serverConfig, this);
Serial.println("[INFO] Server connection configured");
timeCheck.configure(timeConfig);
// Direct messages from the local web server to the controller // Direct messages from the local web server to the controller
local->on("/message", HTTP_POST, [this] (AsyncWebServerRequest *request) { localWebServer.on("/message", HTTP_POST, [this] (AsyncWebServerRequest *request) {
this->handleLocalMessage(request); this->handleLocalMessage(request);
this->sendPreparedLocalResponse(request); this->sendPreparedLocalResponse(request);
}); });
Serial.println("[INFO] Local web server configured");
//storage.resetMessageCounters(); //storage.resetMessageCounters();
storage.printMessageCounters(); storage.printMessageCounters();
} }
void SesameController::loop(uint32_t millis) {
server.loop();
servo.loop(millis);
ensureWiFiConnection(millis);
ensureWebSocketConnection();
}
// MARK: Local // MARK: Local
void void
@ -69,7 +94,7 @@ void SesameController::handleServerMessage(uint8_t* payload, size_t length) {
} }
void SesameController::sendPreparedServerResponse() { void SesameController::sendPreparedServerResponse() {
server->sendResponse(responseBuffer, responseSize); server.sendResponse(responseBuffer, responseSize);
Serial.printf("[INFO] Server response %u (%u bytes)\n", responseBuffer[0], responseSize); Serial.printf("[INFO] Server response %u (%u bytes)\n", responseBuffer[0], responseSize);
} }
@ -90,7 +115,7 @@ void SesameController::processMessage(AuthenticatedMessage* message) {
* @return The response to signal to the server. * @return The response to signal to the server.
*/ */
SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessage* message) { SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessage* message) {
if (!isAuthenticMessage(message, remoteKey, keySize)) { if (!isAuthenticMessage(message, keyConfig.remoteKey, keySize)) {
return SesameEvent::MessageAuthenticationFailed; return SesameEvent::MessageAuthenticationFailed;
} }
if (!storage.isDeviceIdValid(message->message.device)) { if (!storage.isDeviceIdValid(message->message.device)) {
@ -99,13 +124,13 @@ SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessa
if (!storage.isMessageCounterValid(message->message.id, message->message.device)) { if (!storage.isMessageCounterValid(message->message.id, message->message.device)) {
return SesameEvent::MessageCounterInvalid; return SesameEvent::MessageCounterInvalid;
} }
if (!timeCheck->isMessageTimeAcceptable(message->message.time)) { if (!timeCheck.isMessageTimeAcceptable(message->message.time)) {
return SesameEvent::MessageTimeMismatch; return SesameEvent::MessageTimeMismatch;
} }
storage.didUseMessageCounter(message->message.id, message->message.device); storage.didUseMessageCounter(message->message.id, message->message.device);
// Move servo // Move servo
servo->pressButton(); servo.pressButton();
Serial.printf("[Info] Accepted message %d\n", message->message.id); Serial.printf("[Info] Accepted message %d\n", message->message.id);
return SesameEvent::MessageAccepted; return SesameEvent::MessageAccepted;
} }
@ -127,16 +152,42 @@ uint16_t SesameController::prepareResponseBuffer(SesameEvent event, uint8_t devi
if (!allowMessageResponse(event)) { if (!allowMessageResponse(event)) {
return 1; return 1;
} }
responseMessage->message.time = timeCheck->getEpochTime(); responseMessage->message.time = timeCheck.getEpochTime();
responseMessage->message.id = storage.getNextMessageCounter(deviceId); responseMessage->message.id = storage.getNextMessageCounter(deviceId);
responseMessage->message.device = deviceId; responseMessage->message.device = deviceId;
if (!authenticateMessage(responseMessage, localKey, keySize)) { if (!authenticateMessage(responseMessage, keyConfig.localKey, keySize)) {
*responseStatus = SesameEvent::InvalidResponseAuthentication; *responseStatus = SesameEvent::InvalidResponseAuthentication;
return 1; return 1;
} }
return 1 + AUTHENTICATED_MESSAGE_SIZE; return 1 + AUTHENTICATED_MESSAGE_SIZE;
} }
// MARK: Reconnecting
void SesameController::ensureWiFiConnection(uint32_t millis) {
static uint32_t nextWifiReconnect = 0;
// Reconnect to WiFi
if(millis > nextWifiReconnect && WiFi.status() != WL_CONNECTED) {
Serial.println("[INFO] Reconnecting WiFi...");
WiFi.setHostname(wifiConfig.networkName);
WiFi.begin(wifiConfig.ssid, wifiConfig.password);
isReconnecting = true;
nextWifiReconnect = millis + wifiConfig.reconnectInterval;
}
}
void SesameController::ensureWebSocketConnection() {
if (isReconnecting && WiFi.status() == WL_CONNECTED) {
isReconnecting = false;
Serial.print("WiFi IP address: ");
Serial.println(WiFi.localIP());
server.connect();
timeCheck.startNTP();
timeCheck.printLocalTime();
localWebServer.begin();
}
}
// MARK: Helper // MARK: Helper
// Based on https://stackoverflow.com/a/23898449/266720 // Based on https://stackoverflow.com/a/23898449/266720

View File

@ -3,12 +3,14 @@
#include <Arduino.h> // configTime() #include <Arduino.h> // configTime()
#include <time.h> #include <time.h>
TimeCheck::TimeCheck(uint32_t allowedTimeOffset) { TimeCheck::TimeCheck() { }
allowedOffset = allowedTimeOffset;
void TimeCheck::configure(TimeConfiguration configuration) {
config = configuration;
} }
void TimeCheck::configureNTP(int32_t offsetToGMT, int32_t offsetDaylightSavings, const char* serverUrl) { void TimeCheck::startNTP() {
configTime(offsetToGMT, offsetDaylightSavings, serverUrl); configTime(config.offsetToGMT, config.offsetDaylightSavings, config.ntpServerUrl);
} }
void TimeCheck::printLocalTime() { void TimeCheck::printLocalTime() {
@ -37,10 +39,10 @@ bool TimeCheck::isMessageTimeAcceptable(uint32_t t) {
Serial.println("No epoch time available"); Serial.println("No epoch time available");
return false; return false;
} }
if (t > localTime + allowedOffset) { if (t > localTime + config.allowedTimeOffset) {
return false; return false;
} }
if (t < localTime - allowedOffset) { if (t < localTime - config.allowedTimeOffset) {
return false; return false;
} }
return true; return true;

View File

@ -6,87 +6,61 @@
* physical button. * physical button.
*/ */
#include <Arduino.h> #include <Arduino.h>
#include <WiFi.h>
#include <ESPAsyncWebServer.h>
#include "crypto.h"
#include "fresh.h"
#include "message.h" #include "message.h"
#include "server.h" #include "server.h"
#include "servo.h" #include "servo.h"
#include "config.h"
#include "controller.h" #include "controller.h"
#include "config.h"
/* Global variables */ SesameController controller(localWebServerPort, remoteDeviceCount);
TimeCheck timeCheck{};
ServerConnection server(serverUrl, serverPort, serverPath);
ServoController servo{};
AsyncWebServer local(localPort);
SesameController controller(&server, &servo, &local, &timeCheck, remoteDeviceCount);
// Forward declare monitoring functions
void ensureWiFiConnection(uint32_t time);
void ensureWebSocketConnection(uint32_t time);
/* Logic */
void setup() { void setup() {
Serial.begin(serialBaudRate); Serial.begin(serialBaudRate);
Serial.setDebugOutput(true); Serial.setDebugOutput(true);
Serial.println("[INFO] Device started"); Serial.println("[INFO] Device started");
ServoConfiguration servoConfig; ServoConfiguration servoConfig {
servoConfig.pwmTimer = pwmTimer; .pwmTimer = pwmTimer,
servoConfig.pwmFrequency = servoFrequency; .pwmFrequency = servoFrequency,
servoConfig.pin = servoPin; .pin = servoPin,
servoConfig.openDuration = lockOpeningDuration; .openDuration = lockOpeningDuration,
servoConfig.pressedValue = servoPressedState; .pressedValue = servoPressedState,
servoConfig.releasedValue = servoReleasedState; .releasedValue = servoReleasedState,
};
servo.configure(servoConfig); ServerConfiguration serverConfig {
Serial.println("[INFO] Servo configured"); .url = serverUrl,
.port = serverPort,
.path = serverPath,
.key = serverAccessKey,
.reconnectTime = 5000,
};
controller.configure(); TimeConfiguration timeConfig {
.offsetToGMT = timeOffsetToGMT,
.offsetDaylightSavings = timeOffsetDaylightSavings,
.ntpServerUrl = ntpServerUrl,
.allowedTimeOffset = 60,
};
WifiConfiguration wifiConfig {
.ssid = wifiSSID,
.password = wifiPassword,
.networkName = networkName,
.reconnectInterval = wifiReconnectInterval,
};
KeyConfiguration keyConfig {
.remoteKey = remoteKey,
.localKey = localKey,
};
controller.configure(servoConfig, serverConfig, timeConfig, wifiConfig, keyConfig);
} }
void loop() { void loop() {
uint32_t time = millis(); uint32_t time = millis();
server.loop(); controller.loop(time);
servo.loop();
ensureWiFiConnection(time);
ensureWebSocketConnection(time);
}
uint32_t nextWifiReconnect = 0;
bool isReconnecting = false;
void ensureWiFiConnection(uint32_t time) {
// Reconnect to WiFi
if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) {
Serial.println("[INFO] Reconnecting WiFi...");
WiFi.setHostname(networkName);
WiFi.begin(wifiSSID, wifiPassword);
isReconnecting = true;
nextWifiReconnect = time + wifiReconnectInterval;
}
}
void ensureWebSocketConnection(uint32_t time) {
if (isReconnecting && WiFi.status() == WL_CONNECTED) {
isReconnecting = false;
Serial.println("IP address: ");
Serial.println(WiFi.localIP());
Serial.println("[INFO] WiFi connected, opening socket");
server.connect(serverAccessKey);
timeCheck.configureNTP(timeOffsetToGMT, timeOffsetDaylightSavings, ntpServerUrl);
timeCheck.printLocalTime();
local.begin();
}
} }

View File

@ -4,12 +4,14 @@ constexpr int32_t pingInterval = 10000;
constexpr uint32_t pongTimeout = 5000; constexpr uint32_t pongTimeout = 5000;
uint8_t disconnectTimeoutCount = 3; uint8_t disconnectTimeoutCount = 3;
ServerConnection::ServerConnection(const char* url, int port, const char* path) : ServerConnection::ServerConnection() { }
url(url), port(port), path(path) {
void ServerConnection::configure(ServerConfiguration configuration, ServerConnectionCallbacks *callbacks) {
controller = callbacks;
this->configuration = configuration;
} }
void ServerConnection::connect(const char* key, uint32_t reconnectTime) { void ServerConnection::connect() {
if (socketIsConnected) { if (socketIsConnected) {
return; return;
} }
@ -17,25 +19,20 @@ void ServerConnection::connect(const char* key, uint32_t reconnectTime) {
Serial.println("[ERROR] No callbacks set for server"); Serial.println("[ERROR] No callbacks set for server");
return; return;
} }
this->key = key;
webSocket.beginSSL(url, port, path); webSocket.beginSSL(configuration.url, configuration.port, configuration.path);
std::function<void(WStype_t, uint8_t *, size_t)> f = [this](WStype_t type, uint8_t *payload, size_t length) { std::function<void(WStype_t, uint8_t *, size_t)> f = [this](WStype_t type, uint8_t *payload, size_t length) {
this->webSocketEventHandler(type, payload, length); this->webSocketEventHandler(type, payload, length);
}; };
webSocket.onEvent(f); webSocket.onEvent(f);
webSocket.setReconnectInterval(reconnectTime); webSocket.setReconnectInterval(configuration.reconnectTime);
} }
void ServerConnection::loop() { void ServerConnection::loop() {
webSocket.loop(); webSocket.loop();
} }
void ServerConnection::setCallbackHandler(ServerConnectionCallbacks* callbacks) {
controller = callbacks;
}
void ServerConnection::webSocketEventHandler(WStype_t type, uint8_t * payload, size_t length) { void ServerConnection::webSocketEventHandler(WStype_t type, uint8_t * payload, size_t length) {
switch(type) { switch(type) {
case WStype_DISCONNECTED: case WStype_DISCONNECTED:
@ -44,7 +41,7 @@ switch(type) {
break; break;
case WStype_CONNECTED: case WStype_CONNECTED:
socketIsConnected = true; socketIsConnected = true;
webSocket.sendTXT(key); webSocket.sendTXT(configuration.key);
Serial.printf("[INFO] Socket connected to url: %s\n", payload); Serial.printf("[INFO] Socket connected to url: %s\n", payload);
webSocket.enableHeartbeat(pingInterval, pongTimeout, disconnectTimeoutCount); webSocket.enableHeartbeat(pingInterval, pongTimeout, disconnectTimeoutCount);
break; break;

View File

@ -25,8 +25,8 @@ void ServoController::releaseButton() {
buttonIsPressed = false; buttonIsPressed = false;
} }
void ServoController::loop() { void ServoController::loop(uint32_t millis) {
if (buttonIsPressed && millis() > openingEndTime) { if (buttonIsPressed && millis > openingEndTime) {
releaseButton(); releaseButton();
} }
} }