OTA: Add TFTP client mode, expand ota_basic example.

This commit is contained in:
Angus Gratton 2016-03-23 17:33:08 +11:00
parent a3956af4ca
commit e671927bd0
3 changed files with 162 additions and 6 deletions

View file

@ -18,6 +18,45 @@
#include "rboot.h" #include "rboot.h"
#include "rboot-api.h" #include "rboot-api.h"
#define TFTP_IMAGE_SERVER "192.168.1.23"
#define TFTP_IMAGE_FILENAME1 "firmware1.bin"
#define TFTP_IMAGE_FILENAME2 "firmware2.bin"
void tftp_client_task(void *pvParameters)
{
printf("TFTP client task starting...\n");
rboot_config conf;
conf = rboot_get_config();
int slot = (conf.current_rom + 1) % conf.count;
printf("Image will be saved in OTA slot %d.\n", slot);
if(slot == conf.current_rom) {
printf("FATAL ERROR: Only one OTA slot is configured!\n");
while(1) {}
}
/* Alternate between trying two different filenames. Probalby want to change this if making a practical
example!
Note: example will reboot into FILENAME1 if it is successfully downloaded, but FILENAME2 is ignored.
*/
while(1) {
printf("Downloading %s to slot %d...\n", TFTP_IMAGE_FILENAME1, slot);
int res = ota_tftp_download(TFTP_IMAGE_SERVER, TFTP_PORT, TFTP_IMAGE_FILENAME1, 1000, slot);
printf("ota_tftp_download %s result %d\n", TFTP_IMAGE_FILENAME1, res);
if(res == 0) {
printf("Rebooting into slot %d...\n", slot);
rboot_set_current_rom(slot);
sdk_system_restart();
}
vTaskDelay(5000 / portTICK_RATE_MS);
printf("Downloading %s to slot %d...\n", TFTP_IMAGE_FILENAME2, slot);
res = ota_tftp_download(TFTP_IMAGE_SERVER, TFTP_PORT, TFTP_IMAGE_FILENAME2, 1000, slot);
printf("ota_tftp_download %s result %d\n", TFTP_IMAGE_FILENAME2, res);
vTaskDelay(5000 / portTICK_RATE_MS);
}
}
void user_init(void) void user_init(void)
{ {
uart_set_baud(0, 115200); uart_set_baud(0, 115200);
@ -39,4 +78,5 @@ void user_init(void)
sdk_wifi_station_set_config(&config); sdk_wifi_station_set_config(&config);
ota_tftp_init_server(TFTP_PORT); ota_tftp_init_server(TFTP_PORT);
xTaskCreate(&tftp_client_task, (signed char *)"tftp_client", 1024, NULL, 2, NULL);
} }

View file

@ -28,6 +28,7 @@
#define TFTP_FIRMWARE_FILE "firmware.bin" #define TFTP_FIRMWARE_FILE "firmware.bin"
#define TFTP_OCTET_MODE "octet" /* non-case-sensitive */ #define TFTP_OCTET_MODE "octet" /* non-case-sensitive */
#define TFTP_OP_RRQ 1
#define TFTP_OP_WRQ 2 #define TFTP_OP_WRQ 2
#define TFTP_OP_DATA 3 #define TFTP_OP_DATA 3
#define TFTP_OP_ACK 4 #define TFTP_OP_ACK 4
@ -43,8 +44,9 @@
static void tftp_task(void *port_p); static void tftp_task(void *port_p);
static char *tftp_get_field(int field, struct netbuf *netbuf); static char *tftp_get_field(int field, struct netbuf *netbuf);
static err_t tftp_receive_data(struct netconn *nc, size_t write_offs, size_t limit_offs, size_t *received_len); static err_t tftp_receive_data(struct netconn *nc, size_t write_offs, size_t limit_offs, size_t *received_len, ip_addr_t *peer_addr, int peer_port);
static err_t tftp_send_ack(struct netconn *nc, int block); static err_t tftp_send_ack(struct netconn *nc, int block);
static err_t tftp_send_rrq(struct netconn *nc, const char *filename);
static void tftp_send_error(struct netconn *nc, int err_code, const char *err_msg); static void tftp_send_error(struct netconn *nc, int err_code, const char *err_msg);
void ota_tftp_init_server(int listen_port) void ota_tftp_init_server(int listen_port)
@ -52,6 +54,60 @@ void ota_tftp_init_server(int listen_port)
xTaskCreate(tftp_task, (signed char *)"tftpOTATask", 512, (void *)listen_port, 2, NULL); xTaskCreate(tftp_task, (signed char *)"tftpOTATask", 512, (void *)listen_port, 2, NULL);
} }
err_t ota_tftp_download(const char *server, int port, const char *filename, int timeout, int ota_slot)
{
rboot_config rboot_config = rboot_get_config();
/* Validate the OTA slot parameter */
if(rboot_config.current_rom == ota_slot || rboot_config.count <= ota_slot)
{
return ERR_VAL;
}
/* This is all we need to know from the rboot config - where we need
to write data to.
*/
uint32_t flash_offset = rboot_config.roms[ota_slot];
struct netconn *nc = netconn_new (NETCONN_UDP);
err_t err;
if(!nc) {
return ERR_IF;
}
netconn_set_recvtimeout(nc, timeout);
/* try to bind our client port as our local port,
or keep trying the next 10 ports after it */
int local_port = port-1;
do {
err = netconn_bind(nc, IP_ADDR_ANY, ++local_port);
} while(err == ERR_USE && local_port < port + 10);
if(err) {
netconn_delete(nc);
return err;
}
ip_addr_t addr;
err = netconn_gethostbyname(server, &addr);
if(err) {
netconn_delete(nc);
return err;
}
netconn_connect(nc, &addr, port);
err = tftp_send_rrq(nc, filename);
if(err) {
netconn_delete(nc);
return err;
}
size_t received_len;
err = tftp_receive_data(nc, flash_offset, flash_offset+MAX_IMAGE_SIZE, &received_len, &addr, port);
netconn_delete(nc);
return err;
}
static void tftp_task(void *listen_port) static void tftp_task(void *listen_port)
{ {
struct netconn *nc = netconn_new (NETCONN_UDP); struct netconn *nc = netconn_new (NETCONN_UDP);
@ -100,7 +156,7 @@ static void tftp_task(void *listen_port)
/* check mode */ /* check mode */
char *mode = tftp_get_field(1, netbuf); char *mode = tftp_get_field(1, netbuf);
if(!mode || strcmp("octet", mode)) { if(!mode || strcmp(TFTP_OCTET_MODE, mode)) {
tftp_send_error(nc, TFTP_ERR_ILLEGAL, "Mode must be octet/binary"); tftp_send_error(nc, TFTP_ERR_ILLEGAL, "Mode must be octet/binary");
free(mode); free(mode);
netbuf_delete(netbuf); netbuf_delete(netbuf);
@ -133,7 +189,8 @@ static void tftp_task(void *listen_port)
/* Finished WRQ phase, start TFTP data transfer */ /* Finished WRQ phase, start TFTP data transfer */
size_t received_len; size_t received_len;
int recv_err = tftp_receive_data(nc, conf.roms[slot], conf.roms[slot]+MAX_IMAGE_SIZE, &received_len); netconn_set_recvtimeout(nc, 10000);
int recv_err = tftp_receive_data(nc, conf.roms[slot], conf.roms[slot]+MAX_IMAGE_SIZE, &received_len, NULL, 0);
netconn_disconnect(nc); netconn_disconnect(nc);
printf("OTA TFTP receive data result %d bytes %d\r\n", recv_err, received_len); printf("OTA TFTP receive data result %d bytes %d\r\n", recv_err, received_len);
@ -182,20 +239,48 @@ static char *tftp_get_field(int field, struct netbuf *netbuf)
return result; return result;
} }
static err_t tftp_receive_data(struct netconn *nc, size_t write_offs, size_t limit_offs, size_t *received_len) #define TFTP_TIMEOUT_RETRANSMITS 10
static err_t tftp_receive_data(struct netconn *nc, size_t write_offs, size_t limit_offs, size_t *received_len, ip_addr_t *peer_addr, int peer_port)
{ {
*received_len = 0; *received_len = 0;
const int DATA_PACKET_SZ = 512 + 4; /*( packet size plus header */ const int DATA_PACKET_SZ = 512 + 4; /*( packet size plus header */
uint32_t start_offs = write_offs; uint32_t start_offs = write_offs;
int block = 1; int block = 1;
struct netbuf *netbuf; struct netbuf *netbuf = 0;
int retries = TFTP_TIMEOUT_RETRANSMITS;
while(1) while(1)
{ {
netconn_set_recvtimeout(nc, 10000); if(peer_addr) {
netconn_disconnect(nc);
}
err_t err = netconn_recv(nc, &netbuf); err_t err = netconn_recv(nc, &netbuf);
if(peer_addr) {
if(netbuf) {
/* For TFTP server, the UDP connection is already established. But for client,
we don't know what port the server is using until we see the first data
packet - so we connect here.
*/
netconn_connect(nc, netbuf_fromaddr(netbuf), netbuf_fromport(netbuf));
peer_addr = 0;
} else {
/* Otherwise, temporarily re-connect so we can send errors */
netconn_connect(nc, peer_addr, peer_port);
}
}
if(err == ERR_TIMEOUT) { if(err == ERR_TIMEOUT) {
if(retries-- > 0 && block > 1) {
/* Retransmit the last ACK, wait for repeat data block.
This doesn't work for the first block, have to time out and start again. */
tftp_send_ack(nc, block-1);
continue;
}
tftp_send_error(nc, TFTP_ERR_ILLEGAL, "Timeout"); tftp_send_error(nc, TFTP_ERR_ILLEGAL, "Timeout");
return ERR_TIMEOUT; return ERR_TIMEOUT;
} }
@ -225,6 +310,9 @@ static err_t tftp_receive_data(struct netconn *nc, size_t write_offs, size_t lim
} }
} }
/* Reset retry count if we got valid data */
retries = TFTP_TIMEOUT_RETRANSMITS;
if(write_offs % SECTOR_SIZE == 0) { if(write_offs % SECTOR_SIZE == 0) {
sdk_spi_flash_erase_sector(write_offs / SECTOR_SIZE); sdk_spi_flash_erase_sector(write_offs / SECTOR_SIZE);
} }
@ -325,3 +413,17 @@ static void tftp_send_error(struct netconn *nc, int err_code, const char *err_ms
netconn_send(nc, err); netconn_send(nc, err);
netbuf_delete(err); netbuf_delete(err);
} }
static err_t tftp_send_rrq(struct netconn *nc, const char *filename)
{
struct netbuf *rrqbuf = netbuf_new();
uint16_t *rrqdata = (uint16_t *)netbuf_alloc(rrqbuf, 4 + strlen(filename) + strlen(TFTP_OCTET_MODE));
rrqdata[0] = htons(TFTP_OP_RRQ);
char *rrq_filename = (char *)&rrqdata[1];
strcpy(rrq_filename, filename);
strcpy(rrq_filename + strlen(filename) + 1, TFTP_OCTET_MODE);
err_t err = netconn_send(nc, rrqbuf);
netbuf_delete(rrqbuf);
return err;
}

View file

@ -1,5 +1,8 @@
#ifndef _OTA_TFTP_H #ifndef _OTA_TFTP_H
#define _OTA_TFTP_H #define _OTA_TFTP_H
#include "lwip/err.h"
/* TFTP Server OTA Support /* TFTP Server OTA Support
* *
* To use, call ota_tftp_init_server() which will start the TFTP server task * To use, call ota_tftp_init_server() which will start the TFTP server task
@ -30,6 +33,17 @@
*/ */
void ota_tftp_init_server(int listen_port); void ota_tftp_init_server(int listen_port);
/* Attempt to make a TFTP client connection and download the specified filename.
'timeout' is in milliseconds, and is timeout for any UDP exchange
_not_ the entire download.
Returns 0 on success, LWIP err.h values for errors.
Does not change the current firmware slot, or reboot.
*/
err_t ota_tftp_download(const char *server, int port, const char *filename, int timeout, int ota_slot);
#define TFTP_PORT 69 #define TFTP_PORT 69
#endif #endif