diff --git a/app/modules/mqtt.c b/app/modules/mqtt.c index 16e6ce56..41c1c99f 100644 --- a/app/modules/mqtt.c +++ b/app/modules/mqtt.c @@ -1154,51 +1154,49 @@ static int mqtt_socket_subscribe( lua_State* L ) { NODE_DBG("subscribe table\n"); lua_pushnil( L ); /* first key */ - uint8_t temp_buf[MQTT_BUF_SIZE]; - uint32_t temp_pos = 0; + int topic_count = 0; uint8_t overflow = 0; while( lua_next( L, stack ) != 0 ) { topic = luaL_checkstring( L, -2 ); qos = luaL_checkinteger( L, -1 ); - temp_msg = mqtt_msg_subscribe( &mud->mqtt_state.mqtt_connection, topic, qos, &msg_id ); + if (topic_count == 0) { + temp_msg = mqtt_msg_subscribe_init( &mud->mqtt_state.mqtt_connection, &msg_id ); + } + temp_msg = mqtt_msg_subscribe_topic( &mud->mqtt_state.mqtt_connection, topic, qos ); + topic_count++; + NODE_DBG("topic: %s - qos: %d, length: %d\n", topic, qos, temp_msg->length); - if (temp_pos + temp_msg->length > MQTT_BUF_SIZE){ + if (temp_msg->length == 0) { lua_pop(L, 1); overflow = 1; break; // too long message for the outbuffer. } - c_memcpy( temp_buf + temp_pos, temp_msg->data, temp_msg->length ); - temp_pos += temp_msg->length; lua_pop( L, 1 ); } - if (temp_pos == 0){ - luaL_error( L, "invalid data" ); - lua_pushboolean(L, 0); - return 1; + if (topic_count == 0){ + return luaL_error( L, "no topics found" ); } if (overflow != 0){ - luaL_error( L, "buffer overflow, can't enqueue all subscriptions" ); - lua_pushboolean(L, 0); - return 1; + return luaL_error( L, "buffer overflow, can't enqueue all subscriptions" ); + } + + temp_msg = mqtt_msg_subscribe_fini( &mud->mqtt_state.mqtt_connection ); + if (temp_msg->length == 0) { + return luaL_error( L, "buffer overflow, can't enqueue all subscriptions" ); } - c_memcpy( temp_buffer, temp_buf, temp_pos ); - temp_msg->data = temp_buffer; - temp_msg->length = temp_pos; stack++; } else { NODE_DBG("subscribe string\n"); topic = luaL_checklstring( L, stack, &il ); stack++; if( topic == NULL ){ - luaL_error( L, "need topic name" ); - lua_pushboolean(L, 0); - return 1; + return luaL_error( L, "need topic name" ); } qos = luaL_checkinteger( L, stack ); temp_msg = mqtt_msg_subscribe( &mud->mqtt_state.mqtt_connection, topic, qos, &msg_id ); @@ -1257,17 +1255,13 @@ static int mqtt_socket_publish( lua_State* L ) } if(!mud->connected){ - luaL_error( L, "not connected" ); - lua_pushboolean(L, 0); - return 1; + return luaL_error( L, "not connected" ); } const char *topic = luaL_checklstring( L, stack, &l ); stack ++; if (topic == NULL){ - luaL_error( L, "need topic" ); - lua_pushboolean(L, 0); - return 1; + return luaL_error( L, "need topic" ); } const char *payload = luaL_checklstring( L, stack, &l ); diff --git a/app/mqtt/mqtt_msg.c b/app/mqtt/mqtt_msg.c index 9c405a7c..48c80e09 100644 --- a/app/mqtt/mqtt_msg.c +++ b/app/mqtt/mqtt_msg.c @@ -402,14 +402,19 @@ mqtt_message_t* mqtt_msg_pubcomp(mqtt_connection_t* connection, uint16_t message return fini_message(connection, MQTT_MSG_TYPE_PUBCOMP, 0, 0, 0); } -mqtt_message_t* mqtt_msg_subscribe(mqtt_connection_t* connection, const char* topic, int qos, uint16_t* message_id) +mqtt_message_t* mqtt_msg_subscribe_init(mqtt_connection_t* connection, uint16_t *message_id) { init_message(connection); - if(topic == NULL || topic[0] == '\0') + if((*message_id = append_message_id(connection, 0)) == 0) return fail_message(connection); - if((*message_id = append_message_id(connection, 0)) == 0) + return &connection->message; +} + +mqtt_message_t* mqtt_msg_subscribe_topic(mqtt_connection_t* connection, const char* topic, int qos) +{ + if(topic == NULL || topic[0] == '\0') return fail_message(connection); if(append_string(connection, topic, c_strlen(topic)) < 0) @@ -419,9 +424,29 @@ mqtt_message_t* mqtt_msg_subscribe(mqtt_connection_t* connection, const char* to return fail_message(connection); connection->buffer[connection->message.length++] = qos; + return &connection->message; +} + +mqtt_message_t* mqtt_msg_subscribe_fini(mqtt_connection_t* connection) +{ return fini_message(connection, MQTT_MSG_TYPE_SUBSCRIBE, 0, 1, 0); } +mqtt_message_t* mqtt_msg_subscribe(mqtt_connection_t* connection, const char* topic, int qos, uint16_t* message_id) +{ + mqtt_message_t* result; + + result = mqtt_msg_subscribe_init(connection, message_id); + if (result->length != 0) { + result = mqtt_msg_subscribe_topic(connection, topic, qos); + } + if (result->length != 0) { + result = mqtt_msg_subscribe_fini(connection); + } + + return result; +} + mqtt_message_t* mqtt_msg_unsubscribe(mqtt_connection_t* connection, const char* topic, uint16_t* message_id) { init_message(connection); diff --git a/app/mqtt/mqtt_msg.h b/app/mqtt/mqtt_msg.h index 225ba642..e1f372b9 100644 --- a/app/mqtt/mqtt_msg.h +++ b/app/mqtt/mqtt_msg.h @@ -120,6 +120,10 @@ mqtt_message_t* mqtt_msg_pingreq(mqtt_connection_t* connection); mqtt_message_t* mqtt_msg_pingresp(mqtt_connection_t* connection); mqtt_message_t* mqtt_msg_disconnect(mqtt_connection_t* connection); +mqtt_message_t* mqtt_msg_subscribe_init(mqtt_connection_t* connection, uint16_t* message_id); +mqtt_message_t* mqtt_msg_subscribe_topic(mqtt_connection_t* connection, const char* topic, int qos); +mqtt_message_t* mqtt_msg_subscribe_fini(mqtt_connection_t* connection); + #ifdef __cplusplus }