summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--Makefile12
-rw-r--r--websocket.c266
-rw-r--r--websocket.h22
-rw-r--r--wsproxy.c269
5 files changed, 571 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
index 0d20b64..b8e9884 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
*.pyc
+*.o
+wsproxy
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..bd824f8
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,12 @@
+wsproxy: wsproxy.o websocket.o
+ $(CC) $^ -l ssl -l resolv -o $@
+
+#websocket.o: websocket.c
+# $(CC) -c $^ -o $@
+#
+#wsproxy.o: wsproxy.c
+# $(CC) -c $^ -o $@
+
+clean:
+ rm -f wsproxy wsproxy.o websocket.o
+
diff --git a/websocket.c b/websocket.c
new file mode 100644
index 0000000..23e5803
--- /dev/null
+++ b/websocket.c
@@ -0,0 +1,266 @@
+/*
+ * WebSocket lib with support for "wss://" encryption.
+ *
+ * You can make a cert/key with openssl using:
+ * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
+ * as taken from http://docs.python.org/dev/library/ssl.html#certificates
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include <errno.h>
+#include <strings.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include "websocket.h"
+
+const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
+Upgrade: WebSocket\r\n\
+Connection: Upgrade\r\n\
+WebSocket-Origin: %s\r\n\
+WebSocket-Location: %s://%s%s\r\n\
+WebSocket-Protocol: sample\r\n\
+\r\n";
+
+const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
+
+void traffic(char * token) {
+ fprintf(stdout, "%s", token);
+ fflush(stdout);
+}
+
+void error(char *msg)
+{
+ perror(msg);
+}
+
+void fatal(char *msg)
+{
+ perror(msg);
+ exit(1);
+}
+
+/*
+ * SSL Wrapper Code
+ */
+
+/* Warning: not thread safe */
+int ssl_initialized = 0;
+
+ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
+ if (ctx->ssl) {
+ //printf("SSL recv\n");
+ return SSL_read(ctx->ssl, buf, len);
+ } else {
+ return recv(ctx->sockfd, buf, len, 0);
+ }
+}
+
+ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len) {
+ if (ctx->ssl) {
+ //printf("SSL send\n");
+ return SSL_write(ctx->ssl, buf, len);
+ } else {
+ return send(ctx->sockfd, buf, len, 0);
+ }
+}
+
+ws_ctx_t *ws_socket(int socket) {
+ ws_ctx_t *ctx;
+ ctx = malloc(sizeof(ws_ctx_t));
+ ctx->sockfd = socket;
+ ctx->ssl = NULL;
+ ctx->ssl_ctx = NULL;
+ return ctx;
+}
+
+ws_ctx_t *ws_socket_ssl(int socket, char * certfile) {
+ int ret;
+ char msg[1024];
+ ws_ctx_t *ctx;
+ ctx = ws_socket(socket);
+
+ // Initialize the library
+ if (! ssl_initialized) {
+ SSL_library_init();
+ OpenSSL_add_all_algorithms();
+ SSL_load_error_strings();
+ ssl_initialized = 1;
+
+ }
+
+ ctx->ssl_ctx = SSL_CTX_new(TLSv1_server_method());
+ if (ctx->ssl_ctx == NULL) {
+ ERR_print_errors_fp(stderr);
+ fatal("Failed to configure SSL context");
+ }
+
+ if (SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, certfile,
+ SSL_FILETYPE_PEM) <= 0) {
+ sprintf(msg, "Unable to load private key file %s\n", certfile);
+ fatal(msg);
+ }
+
+ if (SSL_CTX_use_certificate_file(ctx->ssl_ctx, certfile,
+ SSL_FILETYPE_PEM) <= 0) {
+ sprintf(msg, "Unable to load certificate file %s\n", certfile);
+ fatal(msg);
+ }
+
+// if (SSL_CTX_set_cipher_list(ctx->ssl_ctx, "DEFAULT") != 1) {
+// sprintf(msg, "Unable to set cipher\n");
+// fatal(msg);
+// }
+
+ // Associate socket and ssl object
+ ctx->ssl = SSL_new(ctx->ssl_ctx);
+ SSL_set_fd(ctx->ssl, socket);
+
+ ret = SSL_accept(ctx->ssl);
+ if (ret < 0) {
+ ERR_print_errors_fp(stderr);
+ return NULL;
+ }
+
+ return ctx;
+}
+
+int ws_socket_free(ws_ctx_t *ctx) {
+ if (ctx->ssl) {
+ SSL_free(ctx->ssl);
+ ctx->ssl = NULL;
+ }
+ if (ctx->ssl_ctx) {
+ SSL_CTX_free(ctx->ssl_ctx);
+ ctx->ssl_ctx = NULL;
+ }
+ if (ctx->sockfd) {
+ close(ctx->sockfd);
+ ctx->sockfd = 0;
+ }
+ free(ctx);
+}
+
+/* ------------------------------------------------------- */
+
+
+ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
+ char handshake[4096], response[4096];
+ char *scheme, *line, *path, *host, *origin;
+ char *args_start, *args_end, *arg_idx;
+ int len;
+ ws_ctx_t * ws_ctx;
+
+ // Reset settings
+ client_settings->b64encode = 0;
+ client_settings->seq_num = 0;
+
+ len = recv(sock, handshake, 1024, MSG_PEEK);
+ handshake[len] = 0;
+ if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
+ len = recv(sock, handshake, 1024, 0);
+ handshake[len] = 0;
+ printf("Sending flash policy response\n");
+ send(sock, policy_response, sizeof(policy_response), 0);
+ close(sock);
+ return NULL;
+ } else if (bcmp(handshake, "\x16", 1) == 0) {
+ // SSL
+ ws_ctx = ws_socket_ssl(sock, "self.pem");
+ if (! ws_ctx) { return NULL; }
+ scheme = "wss";
+ printf("Using SSL socket\n");
+ } else {
+ ws_ctx = ws_socket(sock);
+ if (! ws_ctx) { return NULL; }
+ scheme = "ws";
+ printf("Using plain (not SSL) socket\n");
+ }
+ len = ws_recv(ws_ctx, handshake, 4096);
+ handshake[len] = 0;
+ //printf("handshake: %s\n", handshake);
+ if ((len < 92) || (bcmp(handshake, "GET ", 4) != 0)) {
+ fprintf(stderr, "Invalid WS request\n");
+ return NULL;
+ }
+ strtok(handshake, " "); // Skip "GET "
+ path = strtok(NULL, " "); // Extract path
+ strtok(NULL, "\n"); // Skip to Upgrade line
+ strtok(NULL, "\n"); // Skip to Connection line
+ strtok(NULL, "\n"); // Skip to Host line
+ strtok(NULL, " "); // Skip "Host: "
+ host = strtok(NULL, "\r"); // Extract host
+ strtok(NULL, " "); // Skip "Origin: "
+ origin = strtok(NULL, "\r"); // Extract origin
+
+ //printf("path: %s\n", path);
+ //printf("host: %s\n", host);
+ //printf("origin: %s\n", origin);
+
+ // TODO: parse out client settings
+ args_start = strstr(path, "?");
+ if (args_start) {
+ if (strstr(args_start, "#")) {
+ args_end = strstr(args_start, "#");
+ } else {
+ args_end = args_start + strlen(args_start);
+ }
+ arg_idx = strstr(args_start, "b64encode");
+ if (arg_idx && arg_idx < args_end) {
+ //printf("setting b64encode\n");
+ client_settings->b64encode = 1;
+ }
+ arg_idx = strstr(args_start, "seq_num");
+ if (arg_idx && arg_idx < args_end) {
+ //printf("setting seq_num\n");
+ client_settings->seq_num = 1;
+ }
+ }
+
+ sprintf(response, server_handshake, origin, scheme, host, path);
+ printf("response: %s\n", response);
+ ws_send(ws_ctx, response, strlen(response));
+
+ return ws_ctx;
+}
+
+void start_server(int listen_port,
+ void (*handler)(ws_ctx_t*),
+ client_settings_t *client_settings) {
+ int lsock, csock, clilen, sopt = 1;
+ struct sockaddr_in serv_addr, cli_addr;
+ ws_ctx_t *ws_ctx;
+
+ lsock = socket(AF_INET, SOCK_STREAM, 0);
+ if (lsock < 0) { error("ERROR creating listener socket"); }
+ bzero((char *) &serv_addr, sizeof(serv_addr));
+ serv_addr.sin_family = AF_INET;
+ serv_addr.sin_addr.s_addr = INADDR_ANY;
+ serv_addr.sin_port = htons(listen_port);
+ setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
+ if (bind(lsock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
+ error("ERROR on binding listener socket");
+ }
+ listen(lsock,100);
+
+ while (1) {
+ clilen = sizeof(cli_addr);
+ printf("waiting for connection on port %d\n", listen_port);
+ csock = accept(lsock,
+ (struct sockaddr *) &cli_addr,
+ &clilen);
+ if (csock < 0) {
+ error("ERROR on accept");
+ }
+ printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
+ ws_ctx = do_handshake(csock, client_settings);
+ if (ws_ctx == NULL) { continue; }
+ handler(ws_ctx);
+ close(csock);
+ }
+
+}
+
diff --git a/websocket.h b/websocket.h
new file mode 100644
index 0000000..f9512d9
--- /dev/null
+++ b/websocket.h
@@ -0,0 +1,22 @@
+#include <openssl/ssl.h>
+
+typedef struct {
+ int sockfd;
+ SSL_CTX *ssl_ctx;
+ SSL *ssl;
+} ws_ctx_t;
+
+typedef struct {
+ int b64encode;
+ int seq_num;
+} client_settings_t;
+
+
+ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len);
+
+ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len);
+
+/* base64.c declarations */
+//int b64_ntop(u_char const *src, size_t srclength, char *target, size_t targsize);
+//int b64_pton(char const *src, u_char *target, size_t targsize);
+
diff --git a/wsproxy.c b/wsproxy.c
new file mode 100644
index 0000000..4a1c17b
--- /dev/null
+++ b/wsproxy.c
@@ -0,0 +1,269 @@
+/*
+ * A WebSocket to TCP socket proxy with support for "wss://" encryption.
+ *
+ * You can make a cert/key with openssl using:
+ * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
+ * as taken from http://docs.python.org/dev/library/ssl.html#certificates
+ */
+#include <stdio.h>
+#include <errno.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <sys/select.h>
+#include <resolv.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include "websocket.h"
+
+char traffic_legend[] = "\n\
+Traffic Legend:\n\
+ } - Client receive\n\
+ }. - Client receive partial\n\
+ { - Target receive\n\
+\n\
+ > - Target send\n\
+ >. - Target send partial\n\
+ < - Client send\n\
+ <. - Client send partial\n\
+";
+
+void usage() {
+ fprintf(stderr,"Usage: <listen_port> <target_host> <target_port>\n");
+ exit(1);
+}
+
+char *target_host;
+int target_port;
+client_settings_t client_settings;
+char *record_filename = NULL;
+int recordfd = 0;
+char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
+unsigned int bufsize, dbufsize;
+
+void do_proxy(ws_ctx_t *ws_ctx, int target) {
+ fd_set rlist, wlist, elist;
+ struct timeval tv;
+ int maxfd, client = ws_ctx->sockfd;
+ unsigned int tstart, tend, cstart, cend, ret;
+ ssize_t len, bytes;
+
+ tstart = tend = cstart = cend = 0;
+ maxfd = client > target ? client+1 : target+1;
+ // Account for base64 encoding and WebSocket delims:
+ // 49150 = 65536 * 3/4 + 2 - 1
+
+ while (1) {
+ tv.tv_sec = 1;
+ tv.tv_usec = 0;
+
+ FD_ZERO(&rlist);
+ FD_ZERO(&wlist);
+ FD_ZERO(&elist);
+
+ FD_SET(client, &elist);
+ FD_SET(target, &elist);
+
+ if (tend == tstart) {
+ // Nothing queued for target, so read from client
+ FD_SET(client, &rlist);
+ } else {
+ // Data queued for target, so write to it
+ FD_SET(target, &wlist);
+ }
+ if (cend == cstart) {
+ // Nothing queued for client, so read from target
+ FD_SET(target, &rlist);
+ } else {
+ // Data queued for client, so write to it
+ FD_SET(client, &wlist);
+ }
+
+ ret = select(maxfd, &rlist, &wlist, &elist, &tv);
+
+ if (FD_ISSET(target, &elist)) {
+ fprintf(stderr, "target exception\n");
+ break;
+ }
+ if (FD_ISSET(client, &elist)) {
+ fprintf(stderr, "client exception\n");
+ break;
+ }
+
+ if (ret == -1) {
+ error("select()");
+ break;
+ } else if (ret == 0) {
+ //fprintf(stderr, "select timeout\n");
+ continue;
+ }
+
+ if (FD_ISSET(target, &wlist)) {
+ len = tend-tstart;
+ bytes = send(target, tbuf + tstart, len, 0);
+ if (bytes < 0) {
+ error("target connection error");
+ break;
+ }
+ tstart += bytes;
+ if (tstart >= tend) {
+ tstart = tend = 0;
+ traffic(">");
+ } else {
+ traffic(">.");
+ }
+ }
+
+ if (FD_ISSET(client, &wlist)) {
+ len = cend-cstart;
+ bytes = ws_send(ws_ctx, cbuf + cstart, len);
+ if (len < 3) {
+ fprintf(stderr, "len: %d, bytes: %d: %d\n", len, bytes, *(cbuf + cstart));
+ }
+ cstart += bytes;
+ if (cstart >= cend) {
+ cstart = cend = 0;
+ traffic("<");
+ if (recordfd) {
+ write(recordfd, "'>", 2);
+ write(recordfd, cbuf + cstart + 1, bytes - 2);
+ write(recordfd, "',\n", 3);
+ }
+ } else {
+ traffic("<.");
+ }
+ }
+
+ if (FD_ISSET(target, &rlist)) {
+ bytes = recv(target, cbuf_tmp, dbufsize , 0);
+ if (bytes <= 0) {
+ error("target closed connection");
+ break;
+ }
+ cbuf[0] = '\x00';
+ cstart = 0;
+ len = b64_ntop(cbuf_tmp, bytes, cbuf+1, bufsize-1);
+ if (len < 0) {
+ fprintf(stderr, "base64 encoding error\n");
+ break;
+ }
+ cbuf[len+1] = '\xff';
+ cend = len+1+1;
+ traffic("{");
+ }
+
+ if (FD_ISSET(client, &rlist)) {
+ bytes = ws_recv(ws_ctx, tbuf_tmp, bufsize-1);
+ if (bytes <= 0) {
+ fprintf(stderr, "client closed connection\n");
+ break;
+ }
+ if (tbuf_tmp[bytes-1] != '\xff') {
+ //traffic(".}");
+ fprintf(stderr, "Malformed packet\n");
+ break;
+ }
+ if (recordfd) {
+ write(recordfd, "'", 1);
+ write(recordfd, tbuf_tmp + 1, bytes - 2);
+ write(recordfd, "',\n", 3);
+ }
+ tbuf_tmp[bytes-1] = '\0';
+ len = b64_pton(tbuf_tmp+1, tbuf, bufsize-1);
+ if (len < 0) {
+ fprintf(stderr, "base64 decoding error\n");
+ break;
+ }
+ traffic("}");
+ tstart = 0;
+ tend = len;
+ }
+ }
+}
+
+void proxy_handler(ws_ctx_t *ws_ctx) {
+ int tsock = 0;
+ struct sockaddr_in taddr;
+ struct hostent *thost;
+
+ printf("Connecting to: %s:%d\n", target_host, target_port);
+
+ if (client_settings.b64encode) {
+ dbufsize = (bufsize * 3)/4 + 2 - 10; // padding and for good measure
+ } else {
+ }
+
+ tsock = socket(AF_INET, SOCK_STREAM, 0);
+ if (tsock < 0) {
+ error("Could not create target socket");
+ return;
+ }
+ thost = gethostbyname(target_host);
+ if (thost == NULL) {
+ error("Could not resolve server");
+ close(tsock);
+ return;
+ }
+ bzero((char *) &taddr, sizeof(taddr));
+ taddr.sin_family = AF_INET;
+ bcopy((char *) thost->h_addr,
+ (char *) &taddr.sin_addr.s_addr,
+ thost->h_length);
+ taddr.sin_port = htons(target_port);
+
+ if (connect(tsock, (struct sockaddr *) &taddr, sizeof(taddr)) < 0) {
+ error("Could not connect to target");
+ close(tsock);
+ return;
+ }
+
+ if (record_filename) {
+ recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC,
+ S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+ }
+
+ printf("%s", traffic_legend);
+
+ do_proxy(ws_ctx, tsock);
+
+ close(tsock);
+ if (recordfd) {
+ close(recordfd);
+ recordfd = 0;
+ }
+}
+
+int main(int argc, char *argv[])
+{
+ int listen_port, idx=1;
+
+ if (strcmp(argv[idx], "--record") == 0) {
+ idx++;
+ record_filename = argv[idx++];
+ }
+
+ if ((argc-idx) != 3) { usage(); }
+ listen_port = strtol(argv[idx++], NULL, 10);
+ if (errno != 0) { usage(); }
+ target_host = argv[idx++];
+ target_port = strtol(argv[idx++], NULL, 10);
+ if (errno != 0) { usage(); }
+
+ /* Initialize buffers */
+ bufsize = 65536;
+ if (! (tbuf = malloc(bufsize)) )
+ { fatal("malloc()"); }
+ if (! (cbuf = malloc(bufsize)) )
+ { fatal("malloc()"); }
+ if (! (tbuf_tmp = malloc(bufsize)) )
+ { fatal("malloc()"); }
+ if (! (cbuf_tmp = malloc(bufsize)) )
+ { fatal("malloc()"); }
+
+ start_server(listen_port, &proxy_handler, &client_settings);
+
+ free(tbuf);
+ free(cbuf);
+ free(tbuf_tmp);
+ free(cbuf_tmp);
+}