diff --git a/include/controller.h b/include/controller.h index 9d5acd8..62f2a06 100644 --- a/include/controller.h +++ b/include/controller.h @@ -3,25 +3,37 @@ #include "server.h" #include "servo.h" #include "message.h" -#include "storage.h" -#include "fresh.h" #include -struct WifiConfiguration { +struct EthernetConfiguration { - // The WiFi network to connect to - const char* ssid; + // The MAC address of the ethernet connection + uint8_t macAddress[6]; - // The WiFi password to connect to the above network - const char* password; + // The master-in slave-out pin of the SPI connection for the Ethernet module + int8_t spiPinMiso; - // The name of the device on the network - const char* networkName; + // The master-out slave-in pin of the SPI connection for the Ethernet module + int8_t spiPinMosi; - // The interval to reconnect to WiFi if the connection is broken - uint32_t reconnectInterval; + // The slave clock pin of the SPI connection for the Ethernet module + int8_t spiPinSclk; - uint32_t periodicReconnectInterval; + // The slave-select pin of the SPI connection for the Ethernet module + int8_t spiPinSS; + + unsigned long dhcpLeaseTimeoutMs; + unsigned long dhcpLeaseResponseTimeoutMs; + + // The static IP address to assign if DHCP fails + uint8_t manualIp[4]; + + // The IP address of the DNS server, if DHCP fails + uint8_t manualDnsAddress[4]; + + uint32_t socketHeartbeatIntervalMs; + uint32_t socketHeartbeatTimeoutMs; + uint8_t socketHeartbeatFailureReconnectCount; }; struct KeyConfiguration { @@ -29,55 +41,87 @@ struct KeyConfiguration { const uint8_t* remoteKey; const uint8_t* localKey; + + uint32_t challengeExpiryMs; }; class SesameController: public ServerConnectionCallbacks { public: - SesameController(uint16_t localWebServerPort, uint8_t remoteDeviceCount); + SesameController(uint16_t localWebServerPort); - void configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, TimeConfiguration timeConfig, WifiConfiguration wifiConfig, KeyConfiguration keyConfig); + void configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, EthernetConfiguration ethernetConfig, KeyConfiguration keyConfig); void loop(uint32_t millis); private: + uint32_t currentTime = 0; + ServerConnection server; ServoController servo; AsyncWebServer localWebServer; - TimeCheck timeCheck; - Storage storage; - WifiConfiguration wifiConfig; + EthernetConfiguration ethernetConfig; + bool ethernetIsConfigured = false; + KeyConfiguration keyConfig; bool isReconnecting = false; - // The buffer to hold a received message while it is read - uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; + // Buffer to get local message + SignedMessage receivedLocalMessage; - // The buffer to hold a response while it is sent - uint8_t responseBuffer[AUTHENTICATED_MESSAGE_SIZE+1]; - SesameEvent* responseStatus; - AuthenticatedMessage* responseMessage; - uint16_t responseSize = 0; + uint32_t currentClientChallenge; + uint32_t currentChallengeExpiry = 0; + uint32_t currentServerChallenge; + + SignedMessage outgoingMessage; + + bool hasCurrentChallenge() { + return currentChallengeExpiry > currentTime; + } + + void clearCurrentChallenge() { + currentClientChallenge = 0; + currentServerChallenge = 0; + currentChallengeExpiry = 0; + } + + /** + * @brief Send an error Response over the web socket. + * + * @param result The error result to send + * @param discardMessage Indicate if the stored message should be cleared. + * + * Note: Only clear the message if no other operation is in progress. + */ + void sendErrorResponseToServer(MessageResult result, bool discardMessage = true); - void ensureWiFiConnection(uint32_t time); void ensureWebSocketConnection(); void handleLocalMessage(AsyncWebServerRequest *request); - // Based on https://stackoverflow.com/a/23898449/266720 + bool convertHexMessageToBinary(const char* str); void handleServerMessage(uint8_t* payload, size_t length); - void sendServerError(SesameEvent event); - void processMessage(AuthenticatedMessage* message); - SesameEvent verifyAndProcessReceivedMessage(AuthenticatedMessage* message); + /** + * @brief Callback to send an error back to the server via the web socket. + * + * This function is called when the socket get's an error. + * + * @param event The error to report back + */ + void sendServerError(MessageResult event); - void prepareResponseBuffer(SesameEvent event, uint8_t deviceId = 0); + void processMessage(SignedMessage* message); + MessageResult verifyAndProcessReceivedMessage(SignedMessage* message); + + void prepareResponseBuffer(MessageResult event, Message* message = NULL); void sendPreparedLocalResponse(AsyncWebServerRequest *request); - void sendPreparedServerResponse(); + void sendPreparedResponseToServer(); - void periodicallyReconnectWifiAndSocket(uint32_t millis); + void prepareChallenge(Message* message); + void completeUnlockRequest(Message* message); }; \ No newline at end of file diff --git a/include/crypto.h b/include/crypto.h index 202d1e5..7e35e3b 100644 --- a/include/crypto.h +++ b/include/crypto.h @@ -3,6 +3,15 @@ #include "message.h" #include +void enableCrypto(); + +/** + * @brief Create a random server challenge. + * + * @return uint32_t + */ +uint32_t randomChallenge(); + /** * @brief Create a message authentication code (MAC) for some data. * @@ -10,11 +19,10 @@ * @param dataLength The number of bytes to authenticate * @param mac The output to store the MAC (must be at least 32 bytes) * @param key The secret key used for authentication - * @param keyLength The length of the secret key * @return true The MAC was successfully written * @return false The MAC could not be created */ -bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, const uint8_t* key, size_t keyLength); +bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, const uint8_t* key); /** * @brief Calculate a MAC for message content. @@ -22,22 +30,20 @@ bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, cons * @param message The message for which to calculate the MAC. * @param mac The output where the computed MAC is stored * @param key The secret key used for authentication - * @param keyLength The length of the secret key * @return true The MAC was successfully computed * @return false The MAC could not be created */ -bool authenticateMessage(Message* message, uint8_t* mac, const uint8_t* key, size_t keyLength); +bool authenticateMessage(Message* message, uint8_t* mac, const uint8_t* key); /** * @brief Create a message authentication code (MAC) for a message. * * @param message The message to authenticate * @param key The secret key used for authentication - * @param keyLength The length of the secret key * @return true The MAC was successfully added to the message * @return false The MAC could not be created */ -bool authenticateMessage(AuthenticatedMessage* message, const uint8_t* key, size_t keyLength); +bool authenticateMessage(SignedMessage* message, const uint8_t* key); /** * @brief Check if a received unlock message is authentic @@ -48,8 +54,7 @@ bool authenticateMessage(AuthenticatedMessage* message, const uint8_t* key, size * * @param message The message to authenticate * @param key The secret key used for authentication - * @param keyLength The length of the key in bytes * @return true The message is authentic * @return false The message is invalid, or the MAC could not be calculated */ -bool isAuthenticMessage(AuthenticatedMessage* message, const uint8_t* key, size_t keyLength); \ No newline at end of file +bool isAuthenticMessage(SignedMessage* message, const uint8_t* key); \ No newline at end of file diff --git a/include/fresh.h b/include/fresh.h deleted file mode 100644 index 4604ada..0000000 --- a/include/fresh.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include - -struct TimeConfiguration { - - /** - * @brief The timezone offset in seconds - */ - int32_t offsetToGMT; - - /** - * @brief The daylight savings offset in seconds - */ - int32_t offsetDaylightSavings; - - /** - * @brief The url of the NTP server - */ - const char* ntpServerUrl; - - /** - * @brief The allowed discrepancy between the time of a received message - * and the device time (in seconds) - * - * A stricter (lower) value better prevents against replay attacks, - * but may lead to issues when dealing with slow networks and other - * routing delays. - */ - uint32_t allowedTimeOffset; -}; - -class TimeCheck { - -public: - - /** - * @brief Create a time checker instance - */ - TimeCheck(); - - /** - * @brief Set the configuration - */ - void configure(TimeConfiguration configuration); - - /** - * @brief Configure the NTP server to get the current time - */ - void startNTP(); - - /** - * @brief Print the current time to the serial output - * - * The time must be initialized by calling `configureNTP()` before use. - */ - void printLocalTime(); - - /** - * Gets the current epoch time - */ - uint32_t getEpochTime(); - - /** - * @brief Check wether the time of a message is within the allowed bounds regarding freshness. - * - * The timestamp is used to ensure 'freshness' of the messages, - * i.e. that they are not unreasonably delayed or captured and - * later replayed by an attacker. - * - * @param messageTime The timestamp of the message (seconds since epoch) - * @return true The time is within the acceptable offset of the local time - * @return false The message time is invalid - */ - bool isMessageTimeAcceptable(uint32_t messageTime); - -private: - - TimeConfiguration config; -}; \ No newline at end of file diff --git a/include/message.h b/include/message.h index 074f094..fa8c944 100644 --- a/include/message.h +++ b/include/message.h @@ -14,58 +14,93 @@ #pragma pack(push, 1) +typedef enum { + + /// @brief The initial message from remote to device to request a challenge. + initial = 0, + + /// @brief The second message in an unlock with the challenge from the device to the remote + challenge = 1, + + /// @brief The third message with the signed challenge from the remote to the device + request = 2, + + /// @brief The final message with the unlock result from the device to the remote + response = 3, + +} MessageType; + +enum class MessageResult: uint8_t { + + /// @brief The message was accepted. + MessageAccepted = 0, + + /// @brief The web socket received text while waiting for binary data. + TextReceived = 1, + + /// @brief An unexpected socket event occured while performing the exchange. + UnexpectedSocketEvent = 2, + + /// @brief The received message size is invalid. + InvalidMessageSize = 3, + + /// @brief The message signature was incorrect. + MessageAuthenticationFailed = 4, + + /// @brief The server challenge of the message did not match previous messages + ServerChallengeMismatch = 5, + + /// @brief The client challenge of the message did not match previous messages + ClientChallengeInvalid = 6, + + /// @brief An unexpected or unsupported message type was received + InvalidMessageType = 7, + + /// @brief A message is already being processed + TooManyRequests = 8, + + InvalidUrlParameter = 10, + InvalidResponseAuthentication = 11, + +}; + /** - * @brief The content of an unlock message. - * - * The content is necessary to ensure freshness of the message - * by requiring a recent time and a monotonously increasing counter. - * This prevents messages from being delayed or being blocked and - * replayed later. + * @brief A generic message to exchange during challenge-response authentication. */ typedef struct { - /** - * The timestamp of message creation - * - * The timestamp is encoded as the epoch time, i.e. seconds since 1970 (GMT). - * - * The timestamp is used to ensure 'freshness' of the messages, - * i.e. that they are not unreasonably delayed or captured and - * later replayed by an attacker. - */ - uint32_t time; - - /** - * The counter of unlock messages - * - * This counter must always increase with each message from the remote - * in order for the messages to be deemed valid. Transfering the counters - * back and forth also gives information about lost messages and potential - * attacks. Both the remote and the device keep a record of at least the - * last used counter. - */ - uint32_t id; + /// @brief The type of message being sent. + MessageType messageType; /** - * @brief The id of the device sending the message + * @brief The random nonce created by the remote + * + * This nonce is a random number created by the remote, different for each unlock request. + * It is set for all message types. */ - uint8_t device; + uint32_t clientChallenge; + + /** + * @brief A random number to sign by the remote + * + * This nonce is set by the server after receiving an initial message. + * It is set for the message types `challenge`, `request`, and `response`. + */ + uint32_t serverChallenge; + + /** + * @brief The response status for the previous message. + * + * It is set only for messages from the server, e.g. the `challenge` and `response` message types. + * Must be set to `MessageAccepted` for other messages. + */ + MessageResult result; } Message; -constexpr size_t messageCounterSize = sizeof(uint32_t); - /** - * @brief An authenticated message by the mobile device to command unlocking. + * @brief The signed version of a message. * - * The message is protected by a message authentication code (MAC) based on - * a symmetric key shared by the device and the remote. This code ensures - * that the contents of the request were not altered. The message further - * contains a timestamp to ensure that the message is recent, and not replayed - * by an attacker. An additional counter is also included for this purpose, - * which must continously increase for a message to be valid. This increases - * security a bit, since the timestamp validation must be tolerant to some - * inaccuracy due to mismatching clocks. */ typedef struct { @@ -77,38 +112,18 @@ typedef struct { */ uint8_t mac[SHA256_MAC_SIZE]; - /** - * @brief The message content. - * - * The content is necessary to ensure freshness of the message - * by requiring a recent time and a monotonously increasing counter. - * This prevents messages from being delayed or being blocked and - * replayed later. - */ + /// @brief The message Message message; -} AuthenticatedMessage; +} SignedMessage; + +constexpr size_t messageCounterSize = sizeof(uint32_t); + #pragma pack(pop) constexpr int MESSAGE_CONTENT_SIZE = sizeof(Message); -constexpr int AUTHENTICATED_MESSAGE_SIZE = sizeof(AuthenticatedMessage); - -/** - * An event signaled from the device - */ -enum class SesameEvent { - TextReceived = 1, - UnexpectedSocketEvent = 2, - InvalidMessageSize = 3, - MessageAuthenticationFailed = 4, - MessageTimeMismatch = 5, - MessageCounterInvalid = 6, - MessageAccepted = 7, - MessageDeviceInvalid = 8, - InvalidUrlParameter = 20, - InvalidResponseAuthentication = 21, -}; +constexpr int SIGNED_MESSAGE_SIZE = sizeof(SignedMessage); /** * @brief A callback for messages received over the socket @@ -121,5 +136,5 @@ typedef void (*MessageCallback)(uint8_t* payload, size_t length); /** * @brief A callback for socket errors */ -typedef void (*ErrorCallback)(SesameEvent event); +typedef void (*ErrorCallback)(MessageResult event); diff --git a/include/server.h b/include/server.h index 3767262..a2dc177 100644 --- a/include/server.h +++ b/include/server.h @@ -2,8 +2,6 @@ #include "message.h" #include "crypto.h" -#include -#include #include struct ServerConfiguration { @@ -36,7 +34,7 @@ class ServerConnectionCallbacks { public: - virtual void sendServerError(SesameEvent event) = 0; + virtual void sendServerError(MessageResult event) = 0; virtual void handleServerMessage(uint8_t* payload, size_t length) = 0; }; diff --git a/include/storage.h b/include/storage.h deleted file mode 100644 index 530e4bd..0000000 --- a/include/storage.h +++ /dev/null @@ -1,83 +0,0 @@ -#pragma once - -#include - -class Storage { - -public: - - Storage(uint8_t remoteDeviceCount) : remoteDeviceCount(remoteDeviceCount) { }; - - /** - * @brief Initialize the use of the message counter API - * - * The message counter is stored in EEPROM, which must be initialized before use. - * - * @note The ESP32 does not have a true EEPROM, - * which is emulated using a section of the flash memory. - */ - void configure(); - - /** - * @brief Check if a device ID is allowed - * - * @param deviceId The ID to check - * @return true The id is valid - * @return false The id is invalid - */ - bool isDeviceIdValid(uint8_t deviceId); - - /** - * @brief Check if a received counter is valid - * - * The counter is valid if it is larger than the previous counter - * (larger or equal to the next expected counter). - * - * @param counter The counter to check - * @return true The counter is valid - * @return false The counter belongs to an old message - */ - bool isMessageCounterValid(uint32_t counter, uint8_t deviceId); - - /** - * @brief Mark a counter of a message as used. - * - * The counter value is stored in EEPROM to persist across restarts. - * - * All messages with counters lower than the given one will become invalid. - * - * @param counter The counter used in the last message. - */ - void didUseMessageCounter(uint32_t counter, uint8_t deviceId); - - /** - * @brief Get the expected count for the next message. - * - * The counter is stored in EEPROM to persist across restarts - * - * @return The next counter to use by the remote - */ - uint32_t getNextMessageCounter(uint8_t deviceId); - - /** - * @brief Print info about the current message counter to the serial output - * - */ - void printMessageCounters(); - - /** - * @brief Reset the message counter. - * - * @warning The counter should never be reset in production environments, - * and only together with a new secret key. Otherwise old messages may be - * used for replay attacks. - * - */ - void resetMessageCounters(); - -private: - - uint8_t remoteDeviceCount; - - void setMessageCounter(uint32_t counter, uint8_t deviceId); -}; \ No newline at end of file diff --git a/platformio.ini b/platformio.ini index 6696185..e2fa657 100644 --- a/platformio.ini +++ b/platformio.ini @@ -13,7 +13,10 @@ platform = espressif32 board = az-delivery-devkit-v4 framework = arduino lib_deps = - links2004/WebSockets@^2.3.7 + ; links2004/WebSockets@^2.4.0 madhephaestus/ESP32Servo@^1.1.0 ottowinter/ESPAsyncWebServer-esphome@^3.0.0 + arduino-libraries/Ethernet@^2.0.2 + https://github.com/christophhagen/arduinoWebSockets#master + monitor_speed = 115200 diff --git a/src/controller.cpp b/src/controller.cpp index ee7c9ef..27f76d5 100644 --- a/src/controller.cpp +++ b/src/controller.cpp @@ -2,23 +2,47 @@ #include "crypto.h" #include "config.h" -#include +#include +#include + +SesameController::SesameController(uint16_t localWebServerPort) : localWebServer(localWebServerPort) { -SesameController::SesameController(uint16_t localWebServerPort, uint8_t remoteDeviceCount) : - storage(remoteDeviceCount), localWebServer(localWebServerPort) { - - // Set up response buffer - responseStatus = (SesameEvent*) responseBuffer; - responseMessage = (AuthenticatedMessage*) (responseBuffer + 1); } -void SesameController::configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, TimeConfiguration timeConfig, WifiConfiguration wifiConfig, KeyConfiguration keyConfig) { - this->wifiConfig = wifiConfig; +void SesameController::configure(ServoConfiguration servoConfig, ServerConfiguration serverConfig, EthernetConfiguration ethernetConfig, KeyConfiguration keyConfig) { + this->ethernetConfig = ethernetConfig; this->keyConfig = keyConfig; - // Prepare EEPROM for reading and writing - storage.configure(); - Serial.println("[INFO] Storage configured"); + // Ensure source of random numbers without WiFi and Bluetooth + enableCrypto(); + + // Initialize SPI interface to Ethernet module + SPI.begin(ethernetConfig.spiPinSclk, ethernetConfig.spiPinMiso, ethernetConfig.spiPinMosi, ethernetConfig.spiPinSS); //SCLK, MISO, MOSI, SS + pinMode(ethernetConfig.spiPinSS, OUTPUT); + + Ethernet.init(ethernetConfig.spiPinSS); + + // Check for Ethernet hardware present + if (Ethernet.hardwareStatus() == EthernetNoHardware) { + Serial.println("[ERROR] Ethernet shield not found."); + } else if (Ethernet.linkStatus() == LinkOFF) { + Serial.println("[ERROR] Ethernet cable is not connected."); + } else if (Ethernet.linkStatus() == Unknown) { + Serial.println("[ERROR] Ethernet cable status unknown."); + } else if (Ethernet.linkStatus() == LinkON) { + Serial.println("[INFO] Ethernet cable is connected."); + if (Ethernet.begin(ethernetConfig.macAddress, ethernetConfig.dhcpLeaseTimeoutMs, ethernetConfig.dhcpLeaseResponseTimeoutMs) == 1) { + Serial.print("[INFO] DHCP assigned IP "); + Serial.println(Ethernet.localIP()); + ethernetIsConfigured = true; + } else { + // Try to configure using IP address instead of DHCP + Ethernet.begin(ethernetConfig.macAddress, ethernetConfig.manualIp, ethernetConfig.manualDnsAddress); + Serial.print("[WARNING] DHCP failed, using self-assigned IP "); + Serial.println(Ethernet.localIP()); + ethernetIsConfigured = true; + } + } servo.configure(servoConfig); Serial.println("[INFO] Servo configured"); @@ -26,9 +50,6 @@ void SesameController::configure(ServoConfiguration servoConfig, ServerConfigura // Direct messages and errors over the websocket to the controller server.configure(serverConfig, this); Serial.println("[INFO] Server connection configured"); - - timeCheck.configure(timeConfig); - // Direct messages from the local web server to the controller localWebServer.on("/message", HTTP_POST, [this] (AsyncWebServerRequest *request) { @@ -37,17 +58,13 @@ void SesameController::configure(ServoConfiguration servoConfig, ServerConfigura }); Serial.println("[INFO] Local web server configured"); - - //storage.resetMessageCounters(); - storage.printMessageCounters(); } void SesameController::loop(uint32_t millis) { + currentTime = millis; server.loop(); servo.loop(millis); - periodicallyReconnectWifiAndSocket(millis); - ensureWiFiConnection(millis); ensureWebSocketConnection(); } @@ -56,126 +73,134 @@ void SesameController::loop(uint32_t millis) { void SesameController::handleLocalMessage(AsyncWebServerRequest *request) { if (!request->hasParam(messageUrlParameter)) { Serial.println("Missing url parameter"); - prepareResponseBuffer(SesameEvent::InvalidUrlParameter); + prepareResponseBuffer(MessageResult::InvalidUrlParameter); return; } String encoded = request->getParam(messageUrlParameter)->value(); if (!convertHexMessageToBinary(encoded.c_str())) { Serial.println("Invalid hex encoding"); - prepareResponseBuffer(SesameEvent::InvalidMessageSize); + prepareResponseBuffer(MessageResult::InvalidMessageSize); return; } - processMessage((AuthenticatedMessage*) receivedMessageBuffer); + processMessage(&receivedLocalMessage); } void SesameController::sendPreparedLocalResponse(AsyncWebServerRequest *request) { - request->send_P(200, "application/octet-stream", responseBuffer, responseSize); - Serial.printf("[INFO] Local response %u (%u bytes)\n", responseBuffer[0], responseSize); + request->send_P(200, "application/octet-stream", (uint8_t*) &outgoingMessage, SIGNED_MESSAGE_SIZE); + Serial.printf("[INFO] Local response %u\n", outgoingMessage.message.messageType); } // MARK: Server -void SesameController::sendServerError(SesameEvent event) { - prepareResponseBuffer(event); - sendPreparedServerResponse(); +void SesameController::sendServerError(MessageResult result) { + prepareResponseBuffer(result); // No message to echo + sendPreparedResponseToServer(); } void SesameController::handleServerMessage(uint8_t* payload, size_t length) { - if (length != AUTHENTICATED_MESSAGE_SIZE) { - prepareResponseBuffer(SesameEvent::InvalidMessageSize); + if (length != SIGNED_MESSAGE_SIZE) { + // No message saved to discard, don't accidentally delete for other operation + sendServerError(MessageResult::InvalidMessageSize); return; } - - processMessage((AuthenticatedMessage*) payload); - sendPreparedServerResponse(); + processMessage((SignedMessage*) payload); + sendPreparedResponseToServer(); } -void SesameController::sendPreparedServerResponse() { - server.sendResponse(responseBuffer, responseSize); - Serial.printf("[INFO] Server response %u (%u bytes)\n", responseBuffer[0], responseSize); +void SesameController::sendPreparedResponseToServer() { + server.sendResponse((uint8_t*) &outgoingMessage, SIGNED_MESSAGE_SIZE); + Serial.printf("[INFO] Server response %u\n", outgoingMessage.message.messageType); } // MARK: Message handling -void SesameController::processMessage(AuthenticatedMessage* message) { - SesameEvent event = verifyAndProcessReceivedMessage(message); - prepareResponseBuffer(event, message->message.device); +void SesameController::processMessage(SignedMessage* message) { + // Result must be empty + if (message->message.result != MessageResult::MessageAccepted) { + prepareResponseBuffer(MessageResult::ClientChallengeInvalid); + return; + } + if (!isAuthenticMessage(message, keyConfig.remoteKey)) { + prepareResponseBuffer(MessageResult::MessageAuthenticationFailed); + return; + } + switch (message->message.messageType) { + case MessageType::initial: + prepareChallenge(&message->message); + return; + case MessageType::request: + completeUnlockRequest(&message->message); + return; + default: + prepareResponseBuffer(MessageResult::InvalidMessageType); + return; + } } -/** - * Process a received message. - * - * Checks whether the received data is a valid, - * and then signals that the motor should move. - * - * @param message The message received from the remote - * @return The response to signal to the server. - */ -SesameEvent SesameController::verifyAndProcessReceivedMessage(AuthenticatedMessage* message) { - if (!isAuthenticMessage(message, keyConfig.remoteKey, keySize)) { - return SesameEvent::MessageAuthenticationFailed; +void SesameController::prepareChallenge(Message* message) { + // Server challenge must be empty + if (message->serverChallenge != 0) { + prepareResponseBuffer(MessageResult::ClientChallengeInvalid); + return; } - if (!storage.isDeviceIdValid(message->message.device)) { - return SesameEvent::MessageDeviceInvalid; - } - if (!storage.isMessageCounterValid(message->message.id, message->message.device)) { - return SesameEvent::MessageCounterInvalid; - } - if (!timeCheck.isMessageTimeAcceptable(message->message.time)) { - return SesameEvent::MessageTimeMismatch; + if (hasCurrentChallenge()) { + Serial.println("[INFO] Overwriting old challenge"); } - storage.didUseMessageCounter(message->message.id, message->message.device); + // Set challenge and respond + currentClientChallenge = message->clientChallenge; + currentServerChallenge = randomChallenge(); + currentChallengeExpiry = currentTime + keyConfig.challengeExpiryMs; + + prepareResponseBuffer(MessageResult::MessageAccepted, message); +} + +void SesameController::completeUnlockRequest(Message* message) { + if (!hasCurrentChallenge()) { + prepareResponseBuffer(MessageResult::ClientChallengeInvalid, message); + return; + } + // Client and server challenge must match + if (message->clientChallenge != currentClientChallenge) { + prepareResponseBuffer(MessageResult::ClientChallengeInvalid, message); + return; + } + if (message->serverChallenge != currentServerChallenge) { + prepareResponseBuffer(MessageResult::ServerChallengeMismatch, message); + return; + } + clearCurrentChallenge(); + // Move servo servo.pressButton(); - Serial.printf("[Info] Accepted message %d\n", message->message.id); - return SesameEvent::MessageAccepted; + prepareResponseBuffer(MessageResult::MessageAccepted, message); + Serial.println("[INFO] Accepted message"); } -bool allowMessageResponse(SesameEvent event) { - switch (event) { - case SesameEvent::MessageTimeMismatch: - case SesameEvent::MessageCounterInvalid: - case SesameEvent::MessageAccepted: - case SesameEvent::MessageDeviceInvalid: - return true; - default: - return false; +void SesameController::prepareResponseBuffer(MessageResult result, Message* message) { + outgoingMessage.message.result = result; + if (message != NULL) { + outgoingMessage.message.clientChallenge = message->clientChallenge; + outgoingMessage.message.serverChallenge = message->serverChallenge; + // All outgoing messages are responses, except if an initial message is accepted + if (message->messageType == MessageType::initial && result == MessageResult::MessageAccepted) { + outgoingMessage.message.messageType = MessageType::challenge; + } else { + outgoingMessage.message.messageType = MessageType::response; + } + } else { + outgoingMessage.message.clientChallenge = message->clientChallenge; + outgoingMessage.message.serverChallenge = message->serverChallenge; + outgoingMessage.message.messageType = MessageType::response; } -} -void SesameController::prepareResponseBuffer(SesameEvent event, uint8_t deviceId) { - *responseStatus = event; - responseSize = 1; - if (!allowMessageResponse(event)) { - return; - } - responseMessage->message.time = timeCheck.getEpochTime(); - responseMessage->message.id = storage.getNextMessageCounter(deviceId); - responseMessage->message.device = deviceId; - - if (!authenticateMessage(responseMessage, keyConfig.localKey, keySize)) { - *responseStatus = SesameEvent::InvalidResponseAuthentication; - return; - } - responseSize += AUTHENTICATED_MESSAGE_SIZE; -} - -// MARK: Reconnecting - -void SesameController::ensureWiFiConnection(uint32_t millis) { - static uint32_t nextWifiReconnect = 0; - // Reconnect to WiFi - if(millis > nextWifiReconnect && WiFi.status() != WL_CONNECTED) { - Serial.println("[INFO] Reconnecting WiFi..."); - WiFi.setHostname(wifiConfig.networkName); - WiFi.begin(wifiConfig.ssid, wifiConfig.password); - isReconnecting = true; - nextWifiReconnect = millis + wifiConfig.reconnectInterval; + if (!authenticateMessage(&outgoingMessage, keyConfig.localKey)) { + Serial.println("[ERROR] Failed to sign message"); } } void SesameController::ensureWebSocketConnection() { + /* if (isReconnecting && WiFi.status() == WL_CONNECTED) { isReconnecting = false; Serial.print("WiFi IP address: "); @@ -185,12 +210,22 @@ void SesameController::ensureWebSocketConnection() { timeCheck.printLocalTime(); localWebServer.begin(); } + */ } // MARK: Helper -// Based on https://stackoverflow.com/a/23898449/266720 +/** + * @brief + * + * Based on https://stackoverflow.com/a/23898449/266720 + * + * @param str + * @return true + * @return false + */ bool SesameController::convertHexMessageToBinary(const char* str) { + uint8_t* buffer = (uint8_t*) &receivedLocalMessage; // TODO: Fail if invalid hex values are used uint8_t idx0, idx1; @@ -203,7 +238,7 @@ bool SesameController::convertHexMessageToBinary(const char* str) { }; size_t len = strlen(str); - if (len != AUTHENTICATED_MESSAGE_SIZE * 2) { + if (len != SIGNED_MESSAGE_SIZE * 2) { // Require exact message size return false; } @@ -211,17 +246,7 @@ bool SesameController::convertHexMessageToBinary(const char* str) { for (size_t pos = 0; pos < len; pos += 2) { idx0 = ((uint8_t)str[pos+0] & 0x1F) ^ 0x10; idx1 = ((uint8_t)str[pos+1] & 0x1F) ^ 0x10; - receivedMessageBuffer[pos/2] = (uint8_t)(hashmap[idx0] << 4) | hashmap[idx1]; + buffer[pos/2] = (uint8_t)(hashmap[idx0] << 4) | hashmap[idx1]; }; return true; -} - -void SesameController::periodicallyReconnectWifiAndSocket(uint32_t millis) { - static uint32_t nextWifiReconnect = wifiConfig.periodicReconnectInterval; - if (millis > nextWifiReconnect) { - nextWifiReconnect += wifiConfig.periodicReconnectInterval; - - server.disconnect(); - WiFi.disconnect(); - } } \ No newline at end of file diff --git a/src/crypto.cpp b/src/crypto.cpp index 6c162ed..34b8cfc 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -1,8 +1,19 @@ #include "crypto.h" +#include "config.h" #include #include +#include +#include -bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, const uint8_t* key, size_t keyLength) { +void enableCrypto() { + bootloader_random_enable(); +} + +uint32_t randomChallenge() { + return esp_random(); +} + +bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, const uint8_t* key) { mbedtls_md_context_t ctx; mbedtls_md_type_t md_type = MBEDTLS_MD_SHA256; int result; @@ -12,7 +23,7 @@ bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, cons if (result) { return false; } - result = mbedtls_md_hmac_starts(&ctx, key, keyLength); + result = mbedtls_md_hmac_starts(&ctx, key, keySize); if (result) { return false; } @@ -28,17 +39,17 @@ bool authenticateData(const uint8_t* data, size_t dataLength, uint8_t* mac, cons return true; } -bool authenticateMessage(Message* message, uint8_t* mac, const uint8_t* key, size_t keyLength) { - return authenticateData((const uint8_t*) message, MESSAGE_CONTENT_SIZE, mac, key, keyLength); +bool authenticateMessage(Message* message, uint8_t* mac, const uint8_t* key) { + return authenticateData((const uint8_t*) message, MESSAGE_CONTENT_SIZE, mac, key); } -bool authenticateMessage(AuthenticatedMessage* message, const uint8_t* key, size_t keyLength) { - return authenticateMessage(&message->message, message->mac, key, keyLength); +bool authenticateMessage(SignedMessage* message, const uint8_t* key) { + return authenticateMessage(&message->message, message->mac, key); } -bool isAuthenticMessage(AuthenticatedMessage* message, const uint8_t* key, size_t keyLength) { +bool isAuthenticMessage(SignedMessage* message, const uint8_t* key) { uint8_t mac[SHA256_MAC_SIZE]; - if (!authenticateMessage(&message->message, mac, key, keyLength)) { + if (!authenticateMessage(&message->message, mac, key)) { return false; } return memcmp(mac, message->mac, SHA256_MAC_SIZE) == 0; diff --git a/src/fresh.cpp b/src/fresh.cpp deleted file mode 100644 index db5270d..0000000 --- a/src/fresh.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "fresh.h" - -#include // configTime() -#include - -TimeCheck::TimeCheck() { } - -void TimeCheck::configure(TimeConfiguration configuration) { - config = configuration; -} - -void TimeCheck::startNTP() { - configTime(config.offsetToGMT, config.offsetDaylightSavings, config.ntpServerUrl); -} - -void TimeCheck::printLocalTime() { - struct tm timeinfo; - if (getLocalTime(&timeinfo)) { - Serial.println(&timeinfo, "[INFO] Time is %A, %d. %B %Y %H:%M:%S"); - } else { - Serial.println("[WARN] No local time available"); - } -} - -uint32_t TimeCheck::getEpochTime() { - time_t now; - struct tm timeinfo; - if (!getLocalTime(&timeinfo)) { - Serial.println("[WARN] Failed to obtain local time"); - return(0); - } - time(&now); - return now; -} - -bool TimeCheck::isMessageTimeAcceptable(uint32_t t) { - uint32_t localTime = getEpochTime(); - if (localTime == 0) { - Serial.println("No epoch time available"); - return false; - } - if (t > localTime + config.allowedTimeOffset) { - return false; - } - if (t < localTime - config.allowedTimeOffset) { - return false; - } - return true; -} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 2386f7e..e07658a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -13,7 +13,7 @@ #include "controller.h" #include "config.h" -SesameController controller(localPort, remoteDeviceCount); +SesameController controller(localPort); void setup() { Serial.begin(serialBaudRate); @@ -37,27 +37,28 @@ void setup() { .reconnectTime = 5000, }; - TimeConfiguration timeConfig { - .offsetToGMT = timeOffsetToGMT, - .offsetDaylightSavings = timeOffsetDaylightSavings, - .ntpServerUrl = ntpServerUrl, - .allowedTimeOffset = 60, - }; - - WifiConfiguration wifiConfig { - .ssid = wifiSSID, - .password = wifiPassword, - .networkName = networkName, - .reconnectInterval = wifiReconnectInterval, - .periodicReconnectInterval = wifiPeriodicReconnectInterval, + EthernetConfiguration ethernetConfig { + .macAddress = ethernetMacAddress, + .spiPinMiso = spiPinMiso, + .spiPinMosi = spiPinMosi, + .spiPinSclk = spiPinSclk, + .spiPinSS = spiPinSS, + .dhcpLeaseTimeoutMs = dhcpLeaseTimeoutMs, + .dhcpLeaseResponseTimeoutMs = dhcpLeaseResponseTimeoutMs, + .manualIp = manualIpAddress, + .manualDnsAddress = manualDnsServerAddress, + .socketHeartbeatIntervalMs = socketHeartbeatIntervalMs, + .socketHeartbeatTimeoutMs = socketHeartbeatTimeoutMs, + .socketHeartbeatFailureReconnectCount = socketHeartbeatFailureReconnectCount, }; KeyConfiguration keyConfig { .remoteKey = remoteKey, .localKey = localKey, + .challengeExpiryMs = challengeExpiryMs, }; - controller.configure(servoConfig, serverConfig, timeConfig, wifiConfig, keyConfig); + controller.configure(servoConfig, serverConfig, ethernetConfig, keyConfig); } void loop() { diff --git a/src/server.cpp b/src/server.cpp index 1009932..882a53a 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -48,7 +48,7 @@ switch(type) { webSocket.enableHeartbeat(pingInterval, pongTimeout, disconnectTimeoutCount); break; case WStype_TEXT: - controller->sendServerError(SesameEvent::TextReceived); + controller->sendServerError(MessageResult::TextReceived); break; case WStype_BIN: controller->handleServerMessage(payload, length); @@ -61,7 +61,7 @@ switch(type) { case WStype_FRAGMENT_BIN_START: case WStype_FRAGMENT: case WStype_FRAGMENT_FIN: - controller->sendServerError(SesameEvent::UnexpectedSocketEvent); + controller->sendServerError(MessageResult::UnexpectedSocketEvent); break; } } diff --git a/src/storage.cpp b/src/storage.cpp deleted file mode 100644 index 6218453..0000000 --- a/src/storage.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "storage.h" -#include "message.h" -#include - -void Storage::configure() { - EEPROM.begin(messageCounterSize * remoteDeviceCount); -} - -bool Storage::isDeviceIdValid(uint8_t deviceId) { - return deviceId < remoteDeviceCount; -} - -bool Storage::isMessageCounterValid(uint32_t counter, uint8_t deviceId) { - return counter >= getNextMessageCounter(deviceId); -} - -void Storage::didUseMessageCounter(uint32_t counter, uint8_t deviceId) { - // Store the next counter, so that resetting starts at 0 - setMessageCounter(counter+1, deviceId); -} - -void Storage::setMessageCounter(uint32_t counter, uint8_t deviceId) { - int offset = deviceId * messageCounterSize; - EEPROM.write(offset + 0, (counter >> 24) & 0xFF); - EEPROM.write(offset + 1, (counter >> 16) & 0xFF); - EEPROM.write(offset + 2, (counter >> 8) & 0xFF); - EEPROM.write(offset + 3, counter & 0xFF); - EEPROM.commit(); -} - -uint32_t Storage::getNextMessageCounter(uint8_t deviceId) { - int offset = deviceId * messageCounterSize; - uint32_t counter = (uint32_t) EEPROM.read(offset + 0) << 24; - counter += (uint32_t) EEPROM.read(offset + 1) << 16; - counter += (uint32_t) EEPROM.read(offset + 2) << 8; - counter += (uint32_t) EEPROM.read(offset + 3); - return counter; -} - -void Storage::printMessageCounters() { - Serial.print("[INFO] Next message numbers:"); - for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { - Serial.printf(" %u", getNextMessageCounter(i)); - } - Serial.println(""); -} - -void Storage::resetMessageCounters() { - for (uint8_t i = 0; i < remoteDeviceCount; i += 1) { - setMessageCounter(0, i); - } - Serial.println("[WARN] Message counters reset"); -}