diff --git a/include/base64.h b/include/base64.h new file mode 100644 index 0000000..e7530ee --- /dev/null +++ b/include/base64.h @@ -0,0 +1,30 @@ +#ifndef UTILITY_BASE64_H_ +#define UTILITY_BASE64_H_ + +#include + +/** + * @brief Base64 encode function. + * + * @param[in] *data Pointer to data that will be encoded. + * @param[in] data_length Data length. + * @param[out] *result Pointer to result, encoded data. + * @param[in] max_result_length Maximum result length. + * + * @return 0 = success, in case of fail value different than 0 is returned. + */ +int32_t Base64_encode(const char* data, size_t data_length, char* result, size_t max_result_length); + +/** + * @brief Base64 decode function. + * + * @param[in] *in Pointer to data that will be decoded. + * @param[in] in_len Input data length. + * @param[out] *out Pointer to result data, decoded data. + * @param[in] max_out_len Maximum output length. + * + * @return 0 = success, in case of fail value different than 0 is returned. + */ +int32_t Base64_decode(const char* in, size_t in_len, uint8_t* out, size_t max_out_len); + +#endif /* UTILITY_BASE64_H_ */ \ No newline at end of file diff --git a/include/example_config.h b/include/example_config.h index 3973062..11ae6c4 100644 --- a/include/example_config.h +++ b/include/example_config.h @@ -40,9 +40,20 @@ constexpr const char* wifiSSID = "MyWiFi"; // The WiFi password to connect to the above network constexpr const char* wifiPassword = "00000000"; +// The name of the device on the network +constexpr const char* networkName = "Sesame-Device"; + // The interval to reconnect to WiFi if the connection is broken constexpr uint32_t wifiReconnectInterval = 10000; +/* Local server */ + +// The port for the local server to directly receive messages over WiFi +constexpr uint16_t localPort = 80; + +// The url parameter to send the message to the local server +constexpr char messageUrlParameter[] = "m"; + /* Server */ diff --git a/include/message.h b/include/message.h index d112204..8daf7f5 100644 --- a/include/message.h +++ b/include/message.h @@ -84,4 +84,26 @@ typedef struct { #define MESSAGE_CONTENT_SIZE sizeof(Message) -#define AUTHENTICATED_MESSAGE_SIZE sizeof(AuthenticatedMessage) \ No newline at end of file +#define AUTHENTICATED_MESSAGE_SIZE sizeof(AuthenticatedMessage) + +/** + * An event signaled from the device + */ +enum class SesameEvent { + TextReceived = 1, + UnexpectedSocketEvent = 2, + InvalidMessageData = 3, + MessageAuthenticationFailed = 4, + MessageTimeMismatch = 5, + MessageCounterInvalid = 6, + MessageAccepted = 7, +}; + +/** + * @brief A callback for messages received over the socket + * + * The first parameter is the received message. + * The second parameter is the response to the remote. + * The return value is the type of event to respond with. + */ +typedef SesameEvent (*MessageCallback)(AuthenticatedMessage*, AuthenticatedMessage*); \ No newline at end of file diff --git a/include/server.h b/include/server.h index c48654a..e24609d 100644 --- a/include/server.h +++ b/include/server.h @@ -6,28 +6,6 @@ #include #include -/** - * An event signaled from the device - */ -enum class SesameEvent { - TextReceived = 1, - UnexpectedSocketEvent = 2, - InvalidMessageData = 3, - MessageAuthenticationFailed = 4, - MessageTimeMismatch = 5, - MessageCounterInvalid = 6, - MessageAccepted = 7, -}; - -/** - * @brief A callback for messages received over the socket - * - * The first parameter is the received message. - * The second parameter is the response to the remote. - * The return value is the type of event to respond with. - */ -typedef SesameEvent (*MessageCallback)(AuthenticatedMessage*, AuthenticatedMessage*); - class ServerConnection { public: diff --git a/platformio.ini b/platformio.ini index 7565249..e95b3b3 100644 --- a/platformio.ini +++ b/platformio.ini @@ -14,5 +14,6 @@ board = az-delivery-devkit-v4 framework = arduino lib_deps = links2004/WebSockets@^2.3.7 - madhephaestus/ESP32Servo@^0.11.0 -monitor_speed = 115200 \ No newline at end of file + madhephaestus/ESP32Servo@^0.13.0 + ottowinter/ESPAsyncWebServer-esphome@^3.0.0 +monitor_speed = 115200 diff --git a/src/base64.cpp b/src/base64.cpp new file mode 100644 index 0000000..f7c1861 --- /dev/null +++ b/src/base64.cpp @@ -0,0 +1,172 @@ +/* + * Base64 Decode + * Polfosol + * + * Base64 encoding/decoding (RFC1341) + * Copyright (c) 2005-2011, Jouni Malinen + * + * This software may be distributed under the terms of the BSD license. + * See README for more details. + * + */ + +// Source code from Polfosol: https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c/13935718 +// Source code from Jouni Malinen: https://web.mit.edu/freebsd/head/contrib/wpa/src/utils/base64.c + + +// Encode/Decode functions are modified by Juraj Ciberlin (jciberlin1@gmail.com) to be MISRA C 2012 compliant + +#include "base64.h" + +int32_t +Base64_encode(const char* data, size_t data_length, char* result, size_t max_result_length) { + int32_t success = 0; + const uint8_t base64_table[65] = { + (uint8_t)'A', (uint8_t)'B', (uint8_t)'C', (uint8_t)'D', + (uint8_t)'E', (uint8_t)'F', (uint8_t)'G', (uint8_t)'H', + (uint8_t)'I', (uint8_t)'J', (uint8_t)'K', (uint8_t)'L', + (uint8_t)'M', (uint8_t)'N', (uint8_t)'O', (uint8_t)'P', + (uint8_t)'Q', (uint8_t)'R', (uint8_t)'S', (uint8_t)'T', + (uint8_t)'U', (uint8_t)'V', (uint8_t)'W', (uint8_t)'X', + (uint8_t)'Y', (uint8_t)'Z', (uint8_t)'a', (uint8_t)'b', + (uint8_t)'c', (uint8_t)'d', (uint8_t)'e', (uint8_t)'f', + (uint8_t)'g', (uint8_t)'h', (uint8_t)'i', (uint8_t)'j', + (uint8_t)'k', (uint8_t)'l', (uint8_t)'m', (uint8_t)'n', + (uint8_t)'o', (uint8_t)'p', (uint8_t)'q', (uint8_t)'r', + (uint8_t)'s', (uint8_t)'t', (uint8_t)'u', (uint8_t)'v', + (uint8_t)'w', (uint8_t)'x', (uint8_t)'y', (uint8_t)'z', + (uint8_t)'0', (uint8_t)'1', (uint8_t)'2', (uint8_t)'3', + (uint8_t)'4', (uint8_t)'5', (uint8_t)'6', (uint8_t)'7', + (uint8_t)'8', (uint8_t)'9', (uint8_t)'+', (uint8_t)'/', + (uint8_t)'\0' + }; + uint8_t* out; + const uint8_t* in = (const uint8_t*) data; + + size_t len = 4U * ((data_length + 2U) / 3U); + + if (len < data_length) { + success = 1; + } + + if (success == 0) { + size_t current_length = 0U; + size_t in_position = 0U; + out = (uint8_t*)&result[0]; + uint8_t* pos = out; + while ((data_length - in_position) >= 3U) { + current_length += 4U; + if (current_length > max_result_length) { + success = 1; + break; + } + *pos = base64_table[in[0] >> 2]; + ++pos; + *pos = base64_table[((in[0] & 0x03U) << 4) | (in[1] >> 4)]; + ++pos; + *pos = base64_table[((in[1] & 0x0FU) << 2) | (in[2] >> 6)]; + ++pos; + *pos = base64_table[in[2] & 0x3FU]; + ++pos; + ++in; + ++in; + ++in; + in_position += 3U; + } + + if ((success == 0) && ((data_length - in_position) != 0U)) { + current_length += 4U; + if (current_length > max_result_length) { + success = 1; + } + + if (success == 0) { + *pos = base64_table[in[0] >> 2]; + ++pos; + if ((data_length - in_position) == 1U) { + *pos = base64_table[(in[0] & 0x03U) << 4]; + ++pos; + *pos = (uint8_t)'='; + ++pos; + } else { + *pos = base64_table[((in[0] & 0x03U) << 4) | (in[1] >> 4)]; + ++pos; + *pos = base64_table[(in[1] & 0x0FU) << 2]; + ++pos; + } + *pos = (uint8_t)'='; + ++pos; + } + } + + *pos = (uint8_t)'\0'; + } + + return success; +} + +int32_t +Base64_decode(const char* in, size_t in_len, uint8_t* out, size_t max_out_len) { + int32_t success = 0; + const uint32_t base64_index[256] = { + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 62U, 63U, 62U, 62U, 63U, 52U, 53U, 54U, 55U, 56U, 57U, 58U, 59U, 60U, + 61U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 1U, 2U, 3U, 4U, 5U, 6U, 7U, 8U, 9U, 10U, 11U, + 12U, 13U, 14U, 15U, 16U, 17U, 18U, 19U, 20U, 21U, 22U, 23U, 24U, 25U, 0U, 0U, 0U, + 0U, 63U, 0U, 26U, 27U, 28U, 29U, 30U, 31U, 32U, 33U, 34U, 35U, 36U, 37U, 38U, 39U, + 40U, 41U, 42U, 43U, 44U, 45U, 46U, 47U, 48U, 49U, 50U, 51U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, + 0U + }; + const uint8_t* in_data_uchar = (const uint8_t*)in; + bool pad_bool = (in_len > 0U) && (((in_len % 4U) != 0U) || (in_data_uchar[in_len - 1U] == (uint8_t)'=')); + uint32_t pad_uint = 0U; + if (pad_bool) { + pad_uint = 1U; + } + const size_t len = (((in_len + 3U) / 4U) - pad_uint) * 4U; + const size_t out_len = ((len / 4U) * 3U) + pad_uint; + + if (out_len > max_out_len) { + success = 1; + } + + if (len == 0U) { + success = 1; + } + + if (success == 0) { + size_t j = 0U; + for (size_t i = 0U; i < len; i += 4U) { + uint32_t n = (base64_index[in_data_uchar[i]] << 18U) | (base64_index[in_data_uchar[i + 1U]] << 12U) | + (base64_index[in_data_uchar[i + 2U]] << 6U) | (base64_index[in_data_uchar[i + 3U]]); + out[j] = (uint8_t)(n >> 16U); + ++j; + out[j] = (uint8_t)((n >> 8U) & 0xFFU); + ++j; + out[j] = (uint8_t)(n & 0xFFU); + ++j; + } + if (pad_bool) { + uint32_t n = (base64_index[in_data_uchar[len]] << 18U) | (base64_index[in_data_uchar[len + 1U]] << 12U); + out[out_len - 1U] = (uint8_t)(n >> 16U); + + if ((in_len > (len + 2U)) && (in_data_uchar[len + 2U] != (uint8_t)'=')) { + if ((out_len + 1U) > max_out_len) { + success = 1; + } else { + n |= base64_index[in_data_uchar[len + 2U]] << 6U; + out[out_len] = (uint8_t)((n >> 8U) & 0xFFU); + } + } + } + } + + return success; +} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index e1f5dfe..092bafc 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,6 +7,7 @@ */ #include #include +#include #include "crypto.h" #include "fresh.h" @@ -14,6 +15,7 @@ #include "server.h" #include "servo.h" #include "config.h" +#include "base64.h" /* Global variables */ @@ -21,10 +23,23 @@ ServerConnection server(serverUrl, serverPort, serverPath); ServoController servo(pwmTimer, servoFrequency, servoPin); +AsyncWebServer local(localPort); + +// The buffer to hold a received message while it is read +uint8_t receivedMessageBuffer[AUTHENTICATED_MESSAGE_SIZE]; + /* Event callbacks */ SesameEvent handleReceivedMessage(AuthenticatedMessage* payload, AuthenticatedMessage* response); +// Forward declare monitoring functions +void ensureWiFiConnection(uint32_t time); +void ensureWebSocketConnection(uint32_t time); + +void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event); +void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message); +void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size); + /* Logic */ void setup() { @@ -40,6 +55,26 @@ void setup() { printMessageCounter(); server.onMessage(handleReceivedMessage); + + local.on("/message", HTTP_POST, [] (AsyncWebServerRequest *request) { + if (!request->hasParam(messageUrlParameter)) { + Serial.println("Missing url parameter"); + sendFailureResponse(request, SesameEvent::InvalidMessageData); + return; + } + String encoded = request->getParam(messageUrlParameter)->value(); + int res = Base64_decode(encoded.c_str(),encoded.length(), receivedMessageBuffer, AUTHENTICATED_MESSAGE_SIZE); + if (res) { + Serial.printf("Invalid message length (%d)\n", res); + sendFailureResponse(request, SesameEvent::InvalidMessageData); + return; + } + // Process received message + AuthenticatedMessage* message = (AuthenticatedMessage*) receivedMessageBuffer; + AuthenticatedMessage responseMessage; + SesameEvent event = handleReceivedMessage(message, &responseMessage); + sendMessageResponse(request, event, &responseMessage); + }); } uint32_t nextWifiReconnect = 0; @@ -50,20 +85,33 @@ void loop() { server.loop(); servo.loop(); - + + ensureWiFiConnection(time); + ensureWebSocketConnection(time); +} + +void ensureWiFiConnection(uint32_t time) { // Reconnect to WiFi if(time > nextWifiReconnect && WiFi.status() != WL_CONNECTED) { Serial.println("[INFO] Reconnecting WiFi..."); + WiFi.setHostname(networkName); WiFi.begin(wifiSSID, wifiPassword); isReconnecting = true; nextWifiReconnect = time + wifiReconnectInterval; } + +} + +void ensureWebSocketConnection(uint32_t time) { if (isReconnecting && WiFi.status() == WL_CONNECTED) { isReconnecting = false; + Serial.println("IP address: "); + Serial.println(WiFi.localIP()); Serial.println("[INFO] WiFi connected, opening socket"); server.connectSSL(serverAccessKey); configureNTP(timeOffsetToGMT, timeOffsetDaylightSavings, ntpServerUrl); printLocalTime(); + local.begin(); } } @@ -109,3 +157,21 @@ SesameEvent handleReceivedMessage(AuthenticatedMessage* message, AuthenticatedMe } return event; } + + +void sendFailureResponse(AsyncWebServerRequest *request, SesameEvent event) { + uint8_t response = static_cast(event); + sendResponse(request, &response, 1); +} + +void sendMessageResponse(AsyncWebServerRequest *request, SesameEvent event, AuthenticatedMessage* message) { + uint8_t response[AUTHENTICATED_MESSAGE_SIZE+1]; + response[0] = static_cast(event); + memcpy(response+1, (uint8_t*) message, AUTHENTICATED_MESSAGE_SIZE); + sendResponse(request, response, AUTHENTICATED_MESSAGE_SIZE+1); +} + +void sendResponse(AsyncWebServerRequest *request, uint8_t* buffer, uint8_t size) { + request->send_P(200, "application/octet-stream", buffer, size); + Serial.printf("[INFO] Local response %d\n", buffer[0]); +} \ No newline at end of file