diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | Makefile | 12 | ||||
-rw-r--r-- | websocket.c | 266 | ||||
-rw-r--r-- | websocket.h | 22 | ||||
-rw-r--r-- | wsproxy.c | 269 |
5 files changed, 571 insertions, 0 deletions
@@ -1 +1,3 @@ *.pyc +*.o +wsproxy diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..bd824f8 --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +wsproxy: wsproxy.o websocket.o + $(CC) $^ -l ssl -l resolv -o $@ + +#websocket.o: websocket.c +# $(CC) -c $^ -o $@ +# +#wsproxy.o: wsproxy.c +# $(CC) -c $^ -o $@ + +clean: + rm -f wsproxy wsproxy.o websocket.o + diff --git a/websocket.c b/websocket.c new file mode 100644 index 0000000..23e5803 --- /dev/null +++ b/websocket.c @@ -0,0 +1,266 @@ +/* + * WebSocket lib with support for "wss://" encryption. + * + * You can make a cert/key with openssl using: + * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem + * as taken from http://docs.python.org/dev/library/ssl.html#certificates + */ +#include <stdio.h> +#include <stdlib.h> +#include <errno.h> +#include <strings.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <openssl/err.h> +#include <openssl/ssl.h> +#include "websocket.h" + +const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\ +Upgrade: WebSocket\r\n\ +Connection: Upgrade\r\n\ +WebSocket-Origin: %s\r\n\ +WebSocket-Location: %s://%s%s\r\n\ +WebSocket-Protocol: sample\r\n\ +\r\n"; + +const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n"; + +void traffic(char * token) { + fprintf(stdout, "%s", token); + fflush(stdout); +} + +void error(char *msg) +{ + perror(msg); +} + +void fatal(char *msg) +{ + perror(msg); + exit(1); +} + +/* + * SSL Wrapper Code + */ + +/* Warning: not thread safe */ +int ssl_initialized = 0; + +ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) { + if (ctx->ssl) { + //printf("SSL recv\n"); + return SSL_read(ctx->ssl, buf, len); + } else { + return recv(ctx->sockfd, buf, len, 0); + } +} + +ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len) { + if (ctx->ssl) { + //printf("SSL send\n"); + return SSL_write(ctx->ssl, buf, len); + } else { + return send(ctx->sockfd, buf, len, 0); + } +} + +ws_ctx_t *ws_socket(int socket) { + ws_ctx_t *ctx; + ctx = malloc(sizeof(ws_ctx_t)); + ctx->sockfd = socket; + ctx->ssl = NULL; + ctx->ssl_ctx = NULL; + return ctx; +} + +ws_ctx_t *ws_socket_ssl(int socket, char * certfile) { + int ret; + char msg[1024]; + ws_ctx_t *ctx; + ctx = ws_socket(socket); + + // Initialize the library + if (! ssl_initialized) { + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + ssl_initialized = 1; + + } + + ctx->ssl_ctx = SSL_CTX_new(TLSv1_server_method()); + if (ctx->ssl_ctx == NULL) { + ERR_print_errors_fp(stderr); + fatal("Failed to configure SSL context"); + } + + if (SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, certfile, + SSL_FILETYPE_PEM) <= 0) { + sprintf(msg, "Unable to load private key file %s\n", certfile); + fatal(msg); + } + + if (SSL_CTX_use_certificate_file(ctx->ssl_ctx, certfile, + SSL_FILETYPE_PEM) <= 0) { + sprintf(msg, "Unable to load certificate file %s\n", certfile); + fatal(msg); + } + +// if (SSL_CTX_set_cipher_list(ctx->ssl_ctx, "DEFAULT") != 1) { +// sprintf(msg, "Unable to set cipher\n"); +// fatal(msg); +// } + + // Associate socket and ssl object + ctx->ssl = SSL_new(ctx->ssl_ctx); + SSL_set_fd(ctx->ssl, socket); + + ret = SSL_accept(ctx->ssl); + if (ret < 0) { + ERR_print_errors_fp(stderr); + return NULL; + } + + return ctx; +} + +int ws_socket_free(ws_ctx_t *ctx) { + if (ctx->ssl) { + SSL_free(ctx->ssl); + ctx->ssl = NULL; + } + if (ctx->ssl_ctx) { + SSL_CTX_free(ctx->ssl_ctx); + ctx->ssl_ctx = NULL; + } + if (ctx->sockfd) { + close(ctx->sockfd); + ctx->sockfd = 0; + } + free(ctx); +} + +/* ------------------------------------------------------- */ + + +ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) { + char handshake[4096], response[4096]; + char *scheme, *line, *path, *host, *origin; + char *args_start, *args_end, *arg_idx; + int len; + ws_ctx_t * ws_ctx; + + // Reset settings + client_settings->b64encode = 0; + client_settings->seq_num = 0; + + len = recv(sock, handshake, 1024, MSG_PEEK); + handshake[len] = 0; + if (bcmp(handshake, "<policy-file-request/>", 22) == 0) { + len = recv(sock, handshake, 1024, 0); + handshake[len] = 0; + printf("Sending flash policy response\n"); + send(sock, policy_response, sizeof(policy_response), 0); + close(sock); + return NULL; + } else if (bcmp(handshake, "\x16", 1) == 0) { + // SSL + ws_ctx = ws_socket_ssl(sock, "self.pem"); + if (! ws_ctx) { return NULL; } + scheme = "wss"; + printf("Using SSL socket\n"); + } else { + ws_ctx = ws_socket(sock); + if (! ws_ctx) { return NULL; } + scheme = "ws"; + printf("Using plain (not SSL) socket\n"); + } + len = ws_recv(ws_ctx, handshake, 4096); + handshake[len] = 0; + //printf("handshake: %s\n", handshake); + if ((len < 92) || (bcmp(handshake, "GET ", 4) != 0)) { + fprintf(stderr, "Invalid WS request\n"); + return NULL; + } + strtok(handshake, " "); // Skip "GET " + path = strtok(NULL, " "); // Extract path + strtok(NULL, "\n"); // Skip to Upgrade line + strtok(NULL, "\n"); // Skip to Connection line + strtok(NULL, "\n"); // Skip to Host line + strtok(NULL, " "); // Skip "Host: " + host = strtok(NULL, "\r"); // Extract host + strtok(NULL, " "); // Skip "Origin: " + origin = strtok(NULL, "\r"); // Extract origin + + //printf("path: %s\n", path); + //printf("host: %s\n", host); + //printf("origin: %s\n", origin); + + // TODO: parse out client settings + args_start = strstr(path, "?"); + if (args_start) { + if (strstr(args_start, "#")) { + args_end = strstr(args_start, "#"); + } else { + args_end = args_start + strlen(args_start); + } + arg_idx = strstr(args_start, "b64encode"); + if (arg_idx && arg_idx < args_end) { + //printf("setting b64encode\n"); + client_settings->b64encode = 1; + } + arg_idx = strstr(args_start, "seq_num"); + if (arg_idx && arg_idx < args_end) { + //printf("setting seq_num\n"); + client_settings->seq_num = 1; + } + } + + sprintf(response, server_handshake, origin, scheme, host, path); + printf("response: %s\n", response); + ws_send(ws_ctx, response, strlen(response)); + + return ws_ctx; +} + +void start_server(int listen_port, + void (*handler)(ws_ctx_t*), + client_settings_t *client_settings) { + int lsock, csock, clilen, sopt = 1; + struct sockaddr_in serv_addr, cli_addr; + ws_ctx_t *ws_ctx; + + lsock = socket(AF_INET, SOCK_STREAM, 0); + if (lsock < 0) { error("ERROR creating listener socket"); } + bzero((char *) &serv_addr, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = INADDR_ANY; + serv_addr.sin_port = htons(listen_port); + setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt)); + if (bind(lsock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + error("ERROR on binding listener socket"); + } + listen(lsock,100); + + while (1) { + clilen = sizeof(cli_addr); + printf("waiting for connection on port %d\n", listen_port); + csock = accept(lsock, + (struct sockaddr *) &cli_addr, + &clilen); + if (csock < 0) { + error("ERROR on accept"); + } + printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr)); + ws_ctx = do_handshake(csock, client_settings); + if (ws_ctx == NULL) { continue; } + handler(ws_ctx); + close(csock); + } + +} + diff --git a/websocket.h b/websocket.h new file mode 100644 index 0000000..f9512d9 --- /dev/null +++ b/websocket.h @@ -0,0 +1,22 @@ +#include <openssl/ssl.h> + +typedef struct { + int sockfd; + SSL_CTX *ssl_ctx; + SSL *ssl; +} ws_ctx_t; + +typedef struct { + int b64encode; + int seq_num; +} client_settings_t; + + +ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len); + +ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len); + +/* base64.c declarations */ +//int b64_ntop(u_char const *src, size_t srclength, char *target, size_t targsize); +//int b64_pton(char const *src, u_char *target, size_t targsize); + diff --git a/wsproxy.c b/wsproxy.c new file mode 100644 index 0000000..4a1c17b --- /dev/null +++ b/wsproxy.c @@ -0,0 +1,269 @@ +/* + * A WebSocket to TCP socket proxy with support for "wss://" encryption. + * + * You can make a cert/key with openssl using: + * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem + * as taken from http://docs.python.org/dev/library/ssl.html#certificates + */ +#include <stdio.h> +#include <errno.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> +#include <sys/select.h> +#include <resolv.h> +#include <fcntl.h> +#include <sys/stat.h> +#include "websocket.h" + +char traffic_legend[] = "\n\ +Traffic Legend:\n\ + } - Client receive\n\ + }. - Client receive partial\n\ + { - Target receive\n\ +\n\ + > - Target send\n\ + >. - Target send partial\n\ + < - Client send\n\ + <. - Client send partial\n\ +"; + +void usage() { + fprintf(stderr,"Usage: <listen_port> <target_host> <target_port>\n"); + exit(1); +} + +char *target_host; +int target_port; +client_settings_t client_settings; +char *record_filename = NULL; +int recordfd = 0; +char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp; +unsigned int bufsize, dbufsize; + +void do_proxy(ws_ctx_t *ws_ctx, int target) { + fd_set rlist, wlist, elist; + struct timeval tv; + int maxfd, client = ws_ctx->sockfd; + unsigned int tstart, tend, cstart, cend, ret; + ssize_t len, bytes; + + tstart = tend = cstart = cend = 0; + maxfd = client > target ? client+1 : target+1; + // Account for base64 encoding and WebSocket delims: + // 49150 = 65536 * 3/4 + 2 - 1 + + while (1) { + tv.tv_sec = 1; + tv.tv_usec = 0; + + FD_ZERO(&rlist); + FD_ZERO(&wlist); + FD_ZERO(&elist); + + FD_SET(client, &elist); + FD_SET(target, &elist); + + if (tend == tstart) { + // Nothing queued for target, so read from client + FD_SET(client, &rlist); + } else { + // Data queued for target, so write to it + FD_SET(target, &wlist); + } + if (cend == cstart) { + // Nothing queued for client, so read from target + FD_SET(target, &rlist); + } else { + // Data queued for client, so write to it + FD_SET(client, &wlist); + } + + ret = select(maxfd, &rlist, &wlist, &elist, &tv); + + if (FD_ISSET(target, &elist)) { + fprintf(stderr, "target exception\n"); + break; + } + if (FD_ISSET(client, &elist)) { + fprintf(stderr, "client exception\n"); + break; + } + + if (ret == -1) { + error("select()"); + break; + } else if (ret == 0) { + //fprintf(stderr, "select timeout\n"); + continue; + } + + if (FD_ISSET(target, &wlist)) { + len = tend-tstart; + bytes = send(target, tbuf + tstart, len, 0); + if (bytes < 0) { + error("target connection error"); + break; + } + tstart += bytes; + if (tstart >= tend) { + tstart = tend = 0; + traffic(">"); + } else { + traffic(">."); + } + } + + if (FD_ISSET(client, &wlist)) { + len = cend-cstart; + bytes = ws_send(ws_ctx, cbuf + cstart, len); + if (len < 3) { + fprintf(stderr, "len: %d, bytes: %d: %d\n", len, bytes, *(cbuf + cstart)); + } + cstart += bytes; + if (cstart >= cend) { + cstart = cend = 0; + traffic("<"); + if (recordfd) { + write(recordfd, "'>", 2); + write(recordfd, cbuf + cstart + 1, bytes - 2); + write(recordfd, "',\n", 3); + } + } else { + traffic("<."); + } + } + + if (FD_ISSET(target, &rlist)) { + bytes = recv(target, cbuf_tmp, dbufsize , 0); + if (bytes <= 0) { + error("target closed connection"); + break; + } + cbuf[0] = '\x00'; + cstart = 0; + len = b64_ntop(cbuf_tmp, bytes, cbuf+1, bufsize-1); + if (len < 0) { + fprintf(stderr, "base64 encoding error\n"); + break; + } + cbuf[len+1] = '\xff'; + cend = len+1+1; + traffic("{"); + } + + if (FD_ISSET(client, &rlist)) { + bytes = ws_recv(ws_ctx, tbuf_tmp, bufsize-1); + if (bytes <= 0) { + fprintf(stderr, "client closed connection\n"); + break; + } + if (tbuf_tmp[bytes-1] != '\xff') { + //traffic(".}"); + fprintf(stderr, "Malformed packet\n"); + break; + } + if (recordfd) { + write(recordfd, "'", 1); + write(recordfd, tbuf_tmp + 1, bytes - 2); + write(recordfd, "',\n", 3); + } + tbuf_tmp[bytes-1] = '\0'; + len = b64_pton(tbuf_tmp+1, tbuf, bufsize-1); + if (len < 0) { + fprintf(stderr, "base64 decoding error\n"); + break; + } + traffic("}"); + tstart = 0; + tend = len; + } + } +} + +void proxy_handler(ws_ctx_t *ws_ctx) { + int tsock = 0; + struct sockaddr_in taddr; + struct hostent *thost; + + printf("Connecting to: %s:%d\n", target_host, target_port); + + if (client_settings.b64encode) { + dbufsize = (bufsize * 3)/4 + 2 - 10; // padding and for good measure + } else { + } + + tsock = socket(AF_INET, SOCK_STREAM, 0); + if (tsock < 0) { + error("Could not create target socket"); + return; + } + thost = gethostbyname(target_host); + if (thost == NULL) { + error("Could not resolve server"); + close(tsock); + return; + } + bzero((char *) &taddr, sizeof(taddr)); + taddr.sin_family = AF_INET; + bcopy((char *) thost->h_addr, + (char *) &taddr.sin_addr.s_addr, + thost->h_length); + taddr.sin_port = htons(target_port); + + if (connect(tsock, (struct sockaddr *) &taddr, sizeof(taddr)) < 0) { + error("Could not connect to target"); + close(tsock); + return; + } + + if (record_filename) { + recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC, + S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + } + + printf("%s", traffic_legend); + + do_proxy(ws_ctx, tsock); + + close(tsock); + if (recordfd) { + close(recordfd); + recordfd = 0; + } +} + +int main(int argc, char *argv[]) +{ + int listen_port, idx=1; + + if (strcmp(argv[idx], "--record") == 0) { + idx++; + record_filename = argv[idx++]; + } + + if ((argc-idx) != 3) { usage(); } + listen_port = strtol(argv[idx++], NULL, 10); + if (errno != 0) { usage(); } + target_host = argv[idx++]; + target_port = strtol(argv[idx++], NULL, 10); + if (errno != 0) { usage(); } + + /* Initialize buffers */ + bufsize = 65536; + if (! (tbuf = malloc(bufsize)) ) + { fatal("malloc()"); } + if (! (cbuf = malloc(bufsize)) ) + { fatal("malloc()"); } + if (! (tbuf_tmp = malloc(bufsize)) ) + { fatal("malloc()"); } + if (! (cbuf_tmp = malloc(bufsize)) ) + { fatal("malloc()"); } + + start_server(listen_port, &proxy_handler, &client_settings); + + free(tbuf); + free(cbuf); + free(tbuf_tmp); + free(cbuf_tmp); +} |