support for custom websocket headers (#1573)

Looks good to me. Thank you.

Also:
 - allow for '\0's in received messages

* add client:config for setting websocket headers

Also:
 - headers are case-insensitive now

* fix docs

* fix typo

* remove unnecessary luaL_argcheck calls

* replace os_sprintf with simple string copy
This commit is contained in:
Mariusz Kryński 2016-11-19 16:35:20 +01:00 committed by Philip Gladstone
parent 6331e0868c
commit 59b9b3e26f
4 changed files with 160 additions and 29 deletions

View File

@ -15,6 +15,7 @@
#include "c_types.h" #include "c_types.h"
#include "c_string.h" #include "c_string.h"
#include "c_stdlib.h"
#include "websocketclient.h" #include "websocketclient.h"
@ -45,7 +46,7 @@ static void websocketclient_onConnectionCallback(ws_info *ws) {
} }
} }
static void websocketclient_onReceiveCallback(ws_info *ws, char *message, int opCode) { static void websocketclient_onReceiveCallback(ws_info *ws, int len, char *message, int opCode) {
NODE_DBG("websocketclient_onReceiveCallback\n"); NODE_DBG("websocketclient_onReceiveCallback\n");
lua_State *L = lua_getstate(); lua_State *L = lua_getstate();
@ -59,7 +60,7 @@ static void websocketclient_onReceiveCallback(ws_info *ws, char *message, int op
if (data->onReceive != LUA_NOREF) { if (data->onReceive != LUA_NOREF) {
lua_rawgeti(L, LUA_REGISTRYINDEX, data->onReceive); // load the callback function lua_rawgeti(L, LUA_REGISTRYINDEX, data->onReceive); // load the callback function
lua_rawgeti(L, LUA_REGISTRYINDEX, data->self_ref); // pass itself, #1 callback argument lua_rawgeti(L, LUA_REGISTRYINDEX, data->self_ref); // pass itself, #1 callback argument
lua_pushstring(L, message); // #2 callback argument lua_pushlstring(L, message, len); // #2 callback argument
lua_pushnumber(L, opCode); // #3 callback argument lua_pushnumber(L, opCode); // #3 callback argument
lua_call(L, 3, 0); lua_call(L, 3, 0);
} }
@ -102,6 +103,7 @@ static int websocket_createClient(lua_State *L) {
ws_info *ws = (ws_info *) lua_newuserdata(L, sizeof(ws_info)); ws_info *ws = (ws_info *) lua_newuserdata(L, sizeof(ws_info));
ws->connectionState = 0; ws->connectionState = 0;
ws->extraHeaders = NULL;
ws->onConnection = &websocketclient_onConnectionCallback; ws->onConnection = &websocketclient_onConnectionCallback;
ws->onReceive = &websocketclient_onReceiveCallback; ws->onReceive = &websocketclient_onReceiveCallback;
ws->onFailure = &websocketclient_onCloseCallback; ws->onFailure = &websocketclient_onCloseCallback;
@ -118,7 +120,6 @@ static int websocketclient_on(lua_State *L) {
NODE_DBG("websocketclient_on\n"); NODE_DBG("websocketclient_on\n");
ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT); ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT);
luaL_argcheck(L, ws, 1, "Client websocket expected");
ws_data *data = (ws_data *) ws->reservedData; ws_data *data = (ws_data *) ws->reservedData;
@ -170,7 +171,6 @@ static int websocketclient_connect(lua_State *L) {
NODE_DBG("websocketclient_connect is called.\n"); NODE_DBG("websocketclient_connect is called.\n");
ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT); ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT);
luaL_argcheck(L, ws, 1, "Client websocket expected");
ws_data *data = (ws_data *) ws->reservedData; ws_data *data = (ws_data *) ws->reservedData;
@ -188,11 +188,61 @@ static int websocketclient_connect(lua_State *L) {
return 0; return 0;
} }
static header_t *realloc_headers(header_t *headers, int new_size) {
if(headers) {
for(header_t *header = headers; header->key; header++) {
c_free(header->value);
c_free(header->key);
}
c_free(headers);
}
if(!new_size)
return NULL;
return (header_t *)c_malloc(sizeof(header_t) * (new_size + 1));
}
static int websocketclient_config(lua_State *L) {
NODE_DBG("websocketclient_config is called.\n");
ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT);
ws_data *data = (ws_data *) ws->reservedData;
luaL_checktype(L, 2, LUA_TTABLE);
lua_getfield(L, 2, "headers");
if(lua_istable(L, -1)) {
lua_pushnil(L);
int size = 0;
while(lua_next(L, -2)) {
size++;
lua_pop(L, 1);
}
ws->extraHeaders = realloc_headers(ws->extraHeaders, size);
if(ws->extraHeaders) {
header_t *header = ws->extraHeaders;
lua_pushnil(L);
while(lua_next(L, -2)) {
header->key = c_strdup(lua_tostring(L, -2));
header->value = c_strdup(lua_tostring(L, -1));
header++;
lua_pop(L, 1);
}
header->key = header->value = NULL;
}
}
lua_pop(L, 1); // pop headers
return 0;
}
static int websocketclient_send(lua_State *L) { static int websocketclient_send(lua_State *L) {
NODE_DBG("websocketclient_send is called.\n"); NODE_DBG("websocketclient_send is called.\n");
ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT); ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT);
luaL_argcheck(L, ws, 1, "Client websocket expected");
ws_data *data = (ws_data *) ws->reservedData; ws_data *data = (ws_data *) ws->reservedData;
@ -225,7 +275,8 @@ static int websocketclient_gc(lua_State *L) {
NODE_DBG("websocketclient_gc\n"); NODE_DBG("websocketclient_gc\n");
ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT); ws_info *ws = (ws_info *) luaL_checkudata(L, 1, METATABLE_WSCLIENT);
luaL_argcheck(L, ws, 1, "Client websocket expected");
ws->extraHeaders = realloc_headers(ws->extraHeaders, 0);
ws_data *data = (ws_data *) ws->reservedData; ws_data *data = (ws_data *) ws->reservedData;
@ -265,6 +316,7 @@ static const LUA_REG_TYPE websocket_map[] =
static const LUA_REG_TYPE websocketclient_map[] = static const LUA_REG_TYPE websocketclient_map[] =
{ {
{ LSTRKEY("on"), LFUNCVAL(websocketclient_on) }, { LSTRKEY("on"), LFUNCVAL(websocketclient_on) },
{ LSTRKEY("config"), LFUNCVAL(websocketclient_config) },
{ LSTRKEY("connect"), LFUNCVAL(websocketclient_connect) }, { LSTRKEY("connect"), LFUNCVAL(websocketclient_connect) },
{ LSTRKEY("send"), LFUNCVAL(websocketclient_send) }, { LSTRKEY("send"), LFUNCVAL(websocketclient_send) },
{ LSTRKEY("close"), LFUNCVAL(websocketclient_close) }, { LSTRKEY("close"), LFUNCVAL(websocketclient_close) },

View File

@ -47,18 +47,10 @@
#define PORT_INSECURE 80 #define PORT_INSECURE 80
#define PORT_MAX_VALUE 65535 #define PORT_MAX_VALUE 65535
// TODO: user agent configurable #define WS_INIT_REQUEST "GET %s HTTP/1.1\r\n"\
#define WS_INIT_HEADERS "GET %s HTTP/1.1\r\n"\ "Host: %s:%d\r\n"
"Host: %s:%d\r\n"\
"Upgrade: websocket\r\n"\
"Connection: Upgrade\r\n"\
"User-Agent: ESP8266\r\n"\
"Sec-Websocket-Key: %s\r\n"\
"Sec-WebSocket-Protocol: chat\r\n"\
"Sec-WebSocket-Version: 13\r\n"\
"\r\n"
#define WS_INIT_HEADERS_LENGTH 169 #define WS_INIT_REQUEST_LENGTH 30
#define WS_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" #define WS_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
#define WS_GUID_LENGTH 36 #define WS_GUID_LENGTH 36
@ -77,6 +69,13 @@
#define WS_OPCODE_PING 0x9 #define WS_OPCODE_PING 0x9
#define WS_OPCODE_PONG 0xA #define WS_OPCODE_PONG 0xA
header_t DEFAULT_HEADERS[] = {
{"User-Agent", "ESP8266"},
{"Sec-WebSocket-Protocol", "chat"},
{0}
};
header_t *EMPTY_HEADERS = DEFAULT_HEADERS + sizeof(DEFAULT_HEADERS) / sizeof(header_t) - 1;
static char *cryptoSha1(char *data, unsigned int len) { static char *cryptoSha1(char *data, unsigned int len) {
SHA1_CTX ctx; SHA1_CTX ctx;
SHA1Init(&ctx); SHA1Init(&ctx);
@ -128,6 +127,44 @@ static void generateSecKeys(char **key, char **expectedKey) {
os_free(keyEncrypted); os_free(keyEncrypted);
} }
static char *_strcpy(char *dst, char *src) {
while(*dst++ = *src++);
return dst - 1;
}
static int headers_length(header_t *headers) {
int length = 0;
for(; headers->key; headers++)
length += strlen(headers->key) + strlen(headers->value) + 4;
return length;
}
static char *sprintf_headers(char *buf, ...) {
char *dst = buf;
va_list args;
va_start(args, buf);
for(header_t *header_set = va_arg(args, header_t *); header_set; header_set = va_arg(args, header_t *))
for(header_t *header = header_set; header->key; header++) {
va_list args2;
va_start(args2, buf);
for(header_t *header_set2 = va_arg(args2, header_t *); header_set2; header_set2 = va_arg(args2, header_t *))
for(header_t *header2 = header_set2; header2->key; header2++) {
if(header == header2)
goto ok;
if(!strcasecmp(header->key, header2->key))
goto skip;
}
ok:
dst = _strcpy(dst, header->key);
dst = _strcpy(dst, ": ");
dst = _strcpy(dst, header->value);
dst = _strcpy(dst, "\r\n");
skip:;
}
dst = _strcpy(dst, "\r\n");
return dst;
}
static void ws_closeSentCallback(void *arg) { static void ws_closeSentCallback(void *arg) {
NODE_DBG("ws_closeSentCallback \n"); NODE_DBG("ws_closeSentCallback \n");
struct espconn *conn = (struct espconn *) arg; struct espconn *conn = (struct espconn *) arg;
@ -452,7 +489,7 @@ static void ws_receiveCallback(void *arg, char *buf, unsigned short len) {
} else if (opCode == WS_OPCODE_PONG) { } else if (opCode == WS_OPCODE_PONG) {
// ping alarm was already reset... // ping alarm was already reset...
} else { } else {
if (ws->onReceive) ws->onReceive(ws, payload, opCode); if (ws->onReceive) ws->onReceive(ws, payloadLength, payload, opCode);
} }
os_free(payload); os_free(payload);
} }
@ -509,7 +546,7 @@ static void ws_initReceiveCallback(void *arg, char *buf, unsigned short len) {
} }
// Check server has valid sec key // Check server has valid sec key
if (strstr(buf, WS_HTTP_SEC_WEBSOCKET_ACCEPT) == NULL || strstr(buf, ws->expectedSecKey) == NULL) { if (strstr(buf, ws->expectedSecKey) == NULL) {
NODE_DBG("Server has invalid response\n"); NODE_DBG("Server has invalid response\n");
ws->knownFailureCode = -7; ws->knownFailureCode = -7;
if (ws->isSecure) if (ws->isSecure)
@ -550,12 +587,31 @@ static void connect_callback(void *arg) {
char *key; char *key;
generateSecKeys(&key, &ws->expectedSecKey); generateSecKeys(&key, &ws->expectedSecKey);
char buf[WS_INIT_HEADERS_LENGTH + strlen(ws->path) + strlen(ws->hostname) + strlen(key)]; header_t headers[] = {
int len = os_sprintf(buf, WS_INIT_HEADERS, ws->path, ws->hostname, ws->port, key); {"Upgrade", "websocket"},
{"Connection", "Upgrade"},
{"Sec-WebSocket-Key", key},
{"Sec-WebSocket-Version", "13"},
{0}
};
header_t *extraHeaders = ws->extraHeaders ? ws->extraHeaders : EMPTY_HEADERS;
char buf[WS_INIT_REQUEST_LENGTH + strlen(ws->path) + strlen(ws->hostname) +
headers_length(DEFAULT_HEADERS) + headers_length(headers) + headers_length(extraHeaders) + 2];
int len = os_sprintf(
buf,
WS_INIT_REQUEST,
ws->path,
ws->hostname,
ws->port
);
len = sprintf_headers(buf + len, headers, extraHeaders, DEFAULT_HEADERS, 0) - buf;
os_free(key); os_free(key);
NODE_DBG("request: %s", buf);
NODE_DBG("connecting\n");
if (ws->isSecure) if (ws->isSecure)
espconn_secure_send(conn, (uint8_t *) buf, len); espconn_secure_send(conn, (uint8_t *) buf, len);
else else
@ -630,7 +686,7 @@ static void dns_callback(const char *hostname, ip_addr_t *addr, void *arg) {
ws_info *ws = (ws_info *) conn->reverse; ws_info *ws = (ws_info *) conn->reverse;
if (ws->conn == NULL || ws->connectionState == 4) { if (ws->conn == NULL || ws->connectionState == 4) {
return; return;
} }
if (addr == NULL) { if (addr == NULL) {

View File

@ -40,9 +40,14 @@
struct ws_info; struct ws_info;
typedef void (*ws_onConnectionCallback)(struct ws_info *wsInfo); typedef void (*ws_onConnectionCallback)(struct ws_info *wsInfo);
typedef void (*ws_onReceiveCallback)(struct ws_info *wsInfo, char *message, int opCode); typedef void (*ws_onReceiveCallback)(struct ws_info *wsInfo, int len, char *message, int opCode);
typedef void (*ws_onFailureCallback)(struct ws_info *wsInfo, int errorCode); typedef void (*ws_onFailureCallback)(struct ws_info *wsInfo, int errorCode);
typedef struct {
char *key;
char *value;
} header_t;
typedef struct ws_info { typedef struct ws_info {
int connectionState; int connectionState;
@ -51,6 +56,7 @@ typedef struct ws_info {
int port; int port;
char *path; char *path;
char *expectedSecKey; char *expectedSecKey;
header_t *extraHeaders;
struct espconn *conn; struct espconn *conn;
void *reservedData; void *reservedData;

View File

@ -7,10 +7,6 @@ A websocket *client* module that implements [RFC6455](https://tools.ietf.org/htm
The implementation supports fragmented messages, automatically respondes to ping requests and periodically pings if the server isn't communicating. The implementation supports fragmented messages, automatically respondes to ping requests and periodically pings if the server isn't communicating.
!!! note
Currently, it is **not** possible to change the request headers, most notably the user agent.
**SSL/TLS support** **SSL/TLS support**
Take note of constraints documented in the [net module](net.md). Take note of constraints documented in the [net module](net.md).
@ -70,6 +66,27 @@ ws = nil -- fully dispose the client as lua will now gc it
``` ```
## websocket.client:config(params)
Configures websocket client instance.
#### Syntax
`websocket:config(params)`
#### Parameters
- `params` table with configuration parameters. Following keys are recognized:
- `headers` table of extra request headers affecting every request
#### Returns
`nil`
#### Example
```lua
ws = websocket.createClient()
ws:config({headers={['User-Agent']='NodeMCU'}})
```
## websocket.client:connect() ## websocket.client:connect()
Attempts to estabilish a websocket connection to the given URL. Attempts to estabilish a websocket connection to the given URL.