sdk-ameba-v4.0c_180328/component/common/network/websocket/wsclient_tls.c

245 lines
6 KiB
C
Raw Permalink Normal View History

2019-04-02 08:34:25 +00:00
#include "platform_opts.h"
#include <websocket/libwsclient.h>
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
#include "polarssl/net.h"
#include "polarssl/ssl.h"
#include <polarssl/memory.h>
/* Per-connection TLS state for the PolarSSL backend. The socket descriptor is
 * not stored here: wss_tls_connect() hands the caller's int* to ssl_set_bio(),
 * so the socket lives in the caller's variable. */
struct wss_tls{
ssl_context ctx; /* PolarSSL session context */
};
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
#include "mbedTLS/ssl.h"
#include "mbedtls/net_sockets.h"
/* Per-connection TLS state for the mbed TLS backend. */
struct wss_tls{
mbedtls_ssl_context ctx; /* TLS session; set up from conf by mbedtls_ssl_setup() */
mbedtls_ssl_config conf; /* client configuration (authmode, RNG, defaults) */
mbedtls_net_context socket; /* TCP socket; its fd is also copied to the caller's *sock */
};
/*
 * calloc()-style allocator installed for mbed TLS through
 * mbedtls_platform_set_calloc_free() in wss_tls_connect().
 *
 * Allocates nelements * elementSize bytes with pvPortMalloc and zero-fills
 * them. Returns NULL if the allocation fails or if the size computation would
 * overflow size_t; calloc() semantics require failing rather than silently
 * returning a short buffer.
 */
static void* my_calloc(size_t nelements, size_t elementSize){
	size_t size = nelements * elementSize;
	void *ptr;

	/* Detect wrap-around of the multiplication (unsigned arithmetic is well defined). */
	if(elementSize != 0 && size / elementSize != nelements)
		return NULL;

	ptr = pvPortMalloc(size);
	if(ptr)
		memset(ptr, 0, size);
	return ptr;
}
/*
 * Format value as a decimal string in a newly allocated buffer.
 *
 * The buffer is obtained from pvPortMalloc, so release it with vPortFree.
 * Returns NULL if the allocation fails.
 */
static char *ws_itoa(int value){
	char *val_str;
	/* Count digits on the unsigned magnitude so that negating INT_MIN cannot overflow. */
	unsigned int magnitude = (value < 0) ? 0u - (unsigned int) value : (unsigned int) value;
	/* At least one digit, plus one byte for the '-' sign of negative values. */
	int len = (value < 0) ? 2 : 1;

	while((magnitude /= 10) > 0)
		len++;

	val_str = (char *) pvPortMalloc(len + 1); /* +1 for the terminating NUL */
	if(val_str == NULL)
		return NULL;
	sprintf(val_str, "%d", value);
	return val_str;
}
#endif /* WSCLIENT_USE_TLS */
/* RNG callback handed to the TLS library (ssl_set_rng / mbedtls_ssl_conf_rng in
 * wss_tls_connect). It is only declared in the visible part of this file.
 * NOTE(review): presumably fills output with output_len random bytes and returns
 * 0 on success, per the TLS library's f_rng contract; confirm against its definition. */
int ws_random(void *p_rng, unsigned char *output, size_t output_len);
/*
 * Open a TCP connection to host:port and prepare a TLS client session on it.
 * The TLS handshake itself is NOT performed here; call wss_tls_handshake() next.
 *
 * sock: out-parameter. On success it receives the connected socket descriptor.
 *       On failure any socket opened here is closed and *sock is set to -1.
 * host: server host passed to the net layer.
 * port: TCP port of the server.
 *
 * Returns an opaque session handle for wss_tls_handshake/read/write/close,
 * or NULL on failure.
 *
 * NOTE(review): peer certificate verification is disabled (authmode 0 /
 * MBEDTLS_SSL_VERIFY_NONE), so the server is not authenticated.
 */
void *wss_tls_connect(int *sock , char *host, int port){
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
	int ret = -1;
	struct wss_tls *tls = NULL;

	memory_set_own(pvPortMalloc, vPortFree);
	tls = (struct wss_tls *) malloc(sizeof(struct wss_tls));
	if(tls == NULL){
		printf("\n[WSCLIENT] ERROR: malloc\n");
		return NULL;
	}
	memset(tls, 0, sizeof(struct wss_tls));

	if((ret = net_connect(sock, host, port)) != 0){
		printf("\n[WSCLIENT] ERROR: net_connect %d\n", ret);
		goto exit;
	}
	if((ret = ssl_init(&tls->ctx)) != 0){
		printf("\n[WSCLIENT] ERROR: ssl_init %d\n", ret);
		goto exit;
	}
	ssl_set_endpoint(&tls->ctx, 0);   /* client endpoint */
	ssl_set_authmode(&tls->ctx, 0);   /* no peer certificate verification */
	ssl_set_rng(&tls->ctx, ws_random, NULL);
	/* The bio context is the caller's int*, so I/O always uses the current *sock. */
	ssl_set_bio(&tls->ctx, net_recv, sock, net_send, sock);

exit:
	if(ret != 0){
		net_close(*sock);
		*sock = -1;   /* don't leave the caller holding a closed descriptor */
		ssl_free(&tls->ctx);
		free(tls);
		tls = NULL;
	}
	return (void *) tls;
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
	int ret = -1;
	struct wss_tls *tls = NULL;
	char *port_str = NULL;

	mbedtls_platform_set_calloc_free(my_calloc, vPortFree);
	tls = (struct wss_tls *) malloc(sizeof(struct wss_tls));
	if(tls == NULL){
		printf("\n[WSCLIENT] ERROR: malloc\n");
		return NULL;
	}
	memset(tls, 0, sizeof(struct wss_tls));

	/* Initialise every context up front so the cleanup path is valid from any goto. */
	mbedtls_net_init(&tls->socket);
	mbedtls_ssl_init(&tls->ctx);
	mbedtls_ssl_config_init(&tls->conf);

	port_str = ws_itoa(port);
	if(port_str == NULL){
		printf("\n[WSCLIENT] ERROR: malloc\n");
		ret = -1;
		goto exit;
	}
	ret = mbedtls_net_connect(&tls->socket, host, port_str, MBEDTLS_NET_PROTO_TCP);
	/* port_str is only needed by the connect call; ws_itoa allocates with pvPortMalloc. */
	vPortFree(port_str);
	if(ret != 0){
		printf("\n[WSCLIENT] ERROR: net_connect %d\n", ret);
		goto exit;
	}
	*sock = tls->socket.fd;

	if((ret = mbedtls_ssl_config_defaults(&tls->conf,
			MBEDTLS_SSL_IS_CLIENT,
			MBEDTLS_SSL_TRANSPORT_STREAM,
			MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
		printf("\n[WSCLIENT] ERROR: ssl_config %d\n", ret);
		goto exit;
	}
	mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE);
	mbedtls_ssl_conf_rng(&tls->conf, ws_random, NULL);
	if((ret = mbedtls_ssl_setup(&tls->ctx, &tls->conf)) != 0) {
		printf("\n[WSCLIENT] ERROR: ssl_setup %d\n", ret);
		goto exit;
	}
	mbedtls_ssl_set_bio(&tls->ctx, &tls->socket, mbedtls_net_send, mbedtls_net_recv, NULL);

exit:
	if(ret != 0){
		mbedtls_net_free(&tls->socket);
		mbedtls_ssl_free(&tls->ctx);
		mbedtls_ssl_config_free(&tls->conf);
		free(tls);
		tls = NULL;
		*sock = -1;   /* don't leave the caller holding a closed descriptor */
	}
	return (void *) tls;
#endif /* WSCLIENT_USE_TLS */
}
/*
 * Run the TLS handshake on a session returned by wss_tls_connect().
 * On success, logs the negotiated ciphersuite and returns 0.
 * On failure, logs the library error code and returns -1.
 */
int wss_tls_handshake(void *tls_in){
	struct wss_tls *session = (struct wss_tls *) tls_in;
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
	int rc = ssl_handshake(&session->ctx);
	if(rc == 0){
		printf("\n[WSCLIENT] Use ciphersuite %s\n", ssl_get_ciphersuite(&session->ctx));
		return 0;
	}
	printf("\n[WSCLIENT] ERROR: ssl_handshake %d\n", rc);
	return -1;
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
	int rc = mbedtls_ssl_handshake(&session->ctx);
	if(rc == 0){
		printf("\n[WSCLIENT] Use ciphersuite %s\n", mbedtls_ssl_get_ciphersuite(&session->ctx));
		return 0;
	}
	/* mbed TLS error codes are negative; print them as positive hex. */
	printf("\n[WSCLIENT] ERROR: ssl_handshake -0x%x\n", -rc);
	return -1;
#endif /* WSCLIENT_USE_TLS */
}
/*
 * Tear down a session created by wss_tls_connect() and close its socket.
 *
 * tls_in: handle returned by wss_tls_connect(); may be NULL, in which case only
 *         the socket is closed.
 * sock:   the socket descriptor variable. If it is not -1, the socket is closed
 *         and *sock is reset to -1.
 *
 * The handle is freed; callers must not use tls_in afterwards.
 */
void wss_tls_close(void *tls_in,int *sock){
	struct wss_tls *tls = (struct wss_tls *) tls_in;
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
	if(tls)
		ssl_close_notify(&tls->ctx);
	if(*sock != -1){
		net_close(*sock);
		*sock = -1;
	}
	/* Every access through tls is guarded: a NULL handle must not be dereferenced. */
	if(tls){
		ssl_free(&tls->ctx);
		free(tls);
	}
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
	if(tls)
		mbedtls_ssl_close_notify(&tls->ctx);
	if(*sock != -1){
		if(tls){
			mbedtls_net_free(&tls->socket);
		}
		else{
			/* No session state: close the caller's descriptor directly. */
			mbedtls_net_context orphan;
			orphan.fd = *sock;
			mbedtls_net_free(&orphan);
		}
		*sock = -1;
	}
	/* Every access through tls is guarded: a NULL handle must not be dereferenced. */
	if(tls){
		mbedtls_ssl_free(&tls->ctx);
		mbedtls_ssl_config_free(&tls->conf);
		free(tls);
	}
#endif /* WSCLIENT_USE_TLS */
}
/*
 * Write request_len bytes of request over the TLS session.
 * Returns the library's ssl_write result (bytes written or a negative error).
 * The "want read/write" codes are mapped to 0 (nothing written yet).
 */
int wss_tls_write(void *tls_in, char *request, int request_len){
	struct wss_tls *conn = (struct wss_tls *) tls_in;
	int written;
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
	written = ssl_write(&conn->ctx, request, request_len);
	switch(written){
	case POLARSSL_ERR_NET_WANT_READ:
	case POLARSSL_ERR_NET_WANT_WRITE:
		written = 0;
		break;
	default:
		break;
	}
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
	written = mbedtls_ssl_write(&conn->ctx, request, request_len);
	switch(written){
	case MBEDTLS_ERR_SSL_WANT_READ:
	case MBEDTLS_ERR_SSL_WANT_WRITE:
		written = 0;
		break;
	default:
		break;
	}
#endif /* WSCLIENT_USE_TLS */
	return written;
}
/*
 * Read up to buf_len bytes from the TLS session into buffer.
 * Returns the library's ssl_read result (bytes read or a negative error).
 * The "want read/write" and "receive failed" codes are mapped to 0, so the
 * caller sees them as "no data" rather than as an error.
 */
int wss_tls_read(void *tls_in, char *buffer, int buf_len){
	struct wss_tls *conn = (struct wss_tls *) tls_in;
	int nread;
#if (WSCLIENT_USE_TLS == WSCLIENT_TLS_POLARSSL)
	nread = ssl_read(&conn->ctx, buffer, buf_len);
	switch(nread){
	case POLARSSL_ERR_NET_WANT_READ:
	case POLARSSL_ERR_NET_WANT_WRITE:
	case POLARSSL_ERR_NET_RECV_FAILED:
		nread = 0;
		break;
	default:
		break;
	}
#elif (WSCLIENT_USE_TLS == WSCLIENT_TLS_MBEDTLS)
	nread = mbedtls_ssl_read(&conn->ctx, buffer, buf_len);
	switch(nread){
	case MBEDTLS_ERR_SSL_WANT_READ:
	case MBEDTLS_ERR_SSL_WANT_WRITE:
	case MBEDTLS_ERR_NET_RECV_FAILED:
		nread = 0;
		break;
	default:
		break;
	}
#endif /* WSCLIENT_USE_TLS */
	return nread;
}