1#if defined(__linux__) || defined(__bsd__) || defined(__sunos__)
2#include <netinet/in.h>
3#include <sys/socket.h>
4#endif
5
6#if defined(__windows__)
7#define _WIN32_WINNT 0x0601 // Windows 7 or later
8#include <winsock2.h>
9#include <ws2tcpip.h>
10#endif
11
12#include <openssl/rand.h>
13
14#include "http_message.h"
15#include "websocket.h"
16#include "url.h"
17
18#define MAX_BUFFER_SIZE 1024
19
/**
 * @brief Bit flags describing the state of a WebSocket connection.
 *
 * Each non-zero value is a distinct power of two (bit 0 is unused), so the
 * values can be combined in a flags word, e.g. via vws_set_flag().
 * CNX_CLOSED is the all-clear value.
 */
typedef enum
{
    CNX_CLOSED    = 0x00, /**< The connection with the client is closed. */
    CNX_CONNECTED = 0x02, /**< The connection is established and open. */
    CNX_CLOSING   = 0x04, /**< The connection is in the process of being closed. */
    CNX_SSL_INIT  = 0x08, /**< The connection is in the initial SSL handshake phase. */
    CNX_SERVER    = 0x10  /**< The connection is in server mode. */
} cnx_flags_t;
39
/**
 * @brief Status codes carried in the payload of a WebSocket CLOSE frame
 *        (RFC 6455 section 7.4; 1012-1014 come from the IANA registry).
 *
 * Codes marked "reserved" are not meant to be sent in a CLOSE frame; they
 * exist so applications can report how a connection ended.
 */
typedef enum
{
    WS_CLOSE_NORMAL            = 1000, /**< Normal closure: the connection's purpose was fulfilled. */
    WS_CLOSE_GOING_AWAY        = 1001, /**< Going away: server shutting down, browser navigating away, etc. */
    WS_CLOSE_PROTOCOL_ERROR    = 1002, /**< Endpoint is terminating due to a protocol error. */
    WS_CLOSE_UNSUPPORTED       = 1003, /**< Endpoint received a type of data it cannot accept. */
    WS_CLOSE_RESERVED          = 1004, /**< Reserved; meaning may be defined in the future. */
    WS_CLOSE_NO_STATUS         = 1005, /**< Reserved: no status code was present. */
    WS_CLOSE_ABNORMAL          = 1006, /**< Reserved: closed abnormally, without a CLOSE frame. */
    WS_CLOSE_INVALID_PAYLOAD   = 1007, /**< Message data was inconsistent with its type. */
    WS_CLOSE_POLICY_VIOLATION  = 1008, /**< Message violated the endpoint's policy. */
    WS_CLOSE_TOO_BIG           = 1009, /**< A data frame was too large. */
    WS_CLOSE_MISSING_EXTENSION = 1010, /**< Client wanted an extension the server did not negotiate. */
    WS_CLOSE_INTERNAL_ERROR    = 1011, /**< Server hit an unexpected condition. */
    WS_CLOSE_SERVICE_RESTART   = 1012, /**< Server is restarting. */
    WS_CLOSE_TRY_AGAIN_LATER   = 1013, /**< Temporary condition (e.g. overload); try again later. */
    WS_CLOSE_BAD_GATEWAY       = 1014, /**< Gateway/proxy got an invalid upstream response (like HTTP 502). */
    WS_CLOSE_TLS_HANDSHAKE     = 1015  /**< Reserved: the TLS handshake failed. */
} ws_close_code_t;
111
112//------------------------------------------------------------------------------
113// Internal functions
114//------------------------------------------------------------------------------
115
116/**
117 * @defgroup ConnectionFunctions
118 *
119 * @brief Functions that manage WebSocket connections
120 *
121 */
122
123/**
124 * @brief Attempts connection based on previously parsed URL
125 *
126 * @param c The websocket connection.
127 * @return Returns true if the connection is successful, false otherwise.
128 */
129static bool cnx_connect();
130
131/**
132 * @brief Generates a new, random WebSocket key for the handshake process.
133 *
134 * @return A pointer to the generated WebSocket key.
135 *
136 * @ingroup ConnectionFunctions
137 */
138static char* generate_websocket_key();
139
140/**
141 * @brief Extracts the WebSocket accept key from a server's handshake response.
142 *
143 * @param response The server's handshake response.
144 * @return A pointer to the extracted WebSocket accept key.
145 *
146 * @ingroup ConnectionFunctions
147 */
148static char* extract_websocket_accept_key(const char* response);
149
150/**
151 * @brief Verifies the server's handshake response using the client's original
152 * key.
153 *
154 * @param key The client's original key.
155 * @param response The server's handshake response.
156 * @return 0 if the handshake is verified successfully, an error code otherwise.
157 *
158 * @ingroup ConnectionFunctions
159 */
160static int verify_handshake(const char* key, const char* response);
161
162
163
164
165/**
166 * @defgroup SocketFunctions
167 *
168 * @brief Functions that manage sockets
169 *
170 */
171
172/**
173 * @brief Performs WebSocket connection. This is integrated as callback into
174 * vws_connect().
175 *
176 * @param s The WebSocket connection
177 * @return Returns true if handshake succeeded, false otherwise
178 *
179 * @ingroup ConnectionFunctions
180 */
181
182static bool socket_handshake(vws_socket* s);
183
184/**
185 * @brief Waits for a complete frame to be available from a WebSocket
186 * connection.
187 *
188 * @param c The vws_cnx representing the WebSocket connection.
189 * @return The number of bytes read and processed, or an error code if an error
190 * occurred.
191 *
192 * @ingroup SocketFunctions
193 */
194static ssize_t socket_wait_for_frame(vws_cnx* c);
195
196/**
197 * @defgroup FrameFunctions
198 *
199 * @brief Functions that manage websocket frames
200 *
201 */
202
203/**
204 * @brief Processes a single frame from a WebSocket connection.
205 *
206 * @param c The vws_cnx representing the WebSocket connection.
207 * @param frame The vws_frame to process.
208 * @return void
209 *
210 * @ingroup FrameFunctions
211 */
212static void process_frame(vws_cnx* c, vws_frame* frame);
213
214
215
216
217/**
218 * @defgroup MessageFunctions
219 *
220 * @brief Functions that manage websocket messages
221 *
222 */
223
224/**
225 * @brief Checks if a complete message is available in the connection's message
226 * queue.
227 *
228 * @param c The vws_cnx representing the WebSocket connection.
229 * @return True if a complete message is available, false otherwise.
230 *
231 * @ingroup MessageFunctions
232 */
233static bool has_complete_message(vws_cnx* c);
234
235
236
237
/**
 * @defgroup TraceFunctions
 *
 * @brief Functions that provide tracing and debugging
 *
 */

/**
 * @brief Decoded view of a WebSocket frame header. vws_dump_websocket_frame()
 *        fills it in from raw frame bytes for trace output.
 *
 * @ingroup TraceFunctions
 */
typedef struct
{
    uint8_t fin; /**< FIN bit (0 or 1): set on the final frame of a message */
    uint8_t opcode; /**< Low 4 bits of the first header byte: the frame type */
    uint8_t mask; /**< MASK bit (0 or 1): set if the payload is masked */
    uint64_t payload_len; /**< Payload length after decoding the 7/16/64-bit length field */
    uint32_t masking_key; /**< Masking key read as a big-endian 32-bit value; 0 when mask is 0 */
} ws_header;

/**
 * @brief Prints the decoded fields of a WebSocket header to stdout, for
 *        debugging purposes.
 *
 * @param header The WebSocket header to dump.
 *
 * @ingroup TraceFunctions
 */
static void dump_websocket_header(const ws_header* header);
267
268//------------------------------------------------------------------------------
269//> Connection API
270//------------------------------------------------------------------------------
271
272bool cnx_connect(vws_cnx* c)
273{
274 if (c->url->host == NULL)
275 {
276 vws.error(VE_MEM, "Invalid or missing host");
277 return false;
278 }
279
280 // Connect to the server
281 cstr default_port = strcmp(c->url->protocol, "wss") == 0 ? "443" : "80";
282 cstr port = c->url->port != NULL ? c->url->port : default_port;
283
284 bool ssl = false;
285 if (strcmp(c->url->protocol, "wss") == 0)
286 {
287 ssl = true;
288 }
289
290 return vws_socket_connect((vws_socket*)c, c->url->host, atoi(port), ssl);
291}
292
293vws_cnx* vws_cnx_new()
294{
295 vws_cnx* c = (vws_cnx*)vws.malloc(sizeof(vws_cnx));
296 memset(c, 0, sizeof(vws_cnx));
297
298 // Call base constructor
299 vws_socket_ctor((vws_socket*)c);
300
301 c->base.hs = socket_handshake;
302 c->flags = CNX_CLOSED;
303 c->url = NULL;
304 c->key = generate_websocket_key();
305 c->process = process_frame;
306 c->disconnect = NULL;
307 c->data = NULL;
308
309 sc_queue_init(&c->queue);
310
311 return c;
312}
313
314void vws_cnx_free(vws_cnx* c)
315{
316 if (c == NULL)
317 {
318 return;
319 }
320
321 vws_disconnect(c);
322
323 // Free receive queue contents
324 vws_frame* f;
325 sc_queue_foreach (&c->queue, f)
326 {
327 vws_frame_free(f);
328 }
329
330 // Free receive queue
331 sc_queue_term(&c->queue);
332
333 // Free URL
334 if (c->url != NULL)
335 {
336 url_free((url_data_t*)c->url);
337 c->url = NULL;
338 }
339
340 // Free websocket key
341 vws.free(c->key);
342
343 // Call base constructor
344 vws_socket_dtor((vws_socket*)c);
345}
346
347void vws_cnx_set_server_mode(vws_cnx* c)
348{
349 vws_set_flag(&c->flags, CNX_SERVER);
350}
351
352bool vws_connect(vws_cnx* c, cstr uri)
353{
354 if (c == NULL)
355 {
356 // Return early if failed to create a connection.
357 vws.error(VE_RT, "Invalid connection pointer()");
358 return false;
359 }
360
361 if (c->url != NULL)
362 {
363 url_free((url_data_t*)c->url);
364 }
365
366 c->url = (vws_url_data*)url_parse(uri);
367
368 return cnx_connect(c);
369}
370
371bool vws_reconnect(vws_cnx* c)
372{
373 if (vws_cnx_is_connected(c) == true)
374 {
375 return true;
376 }
377
378 if (c->url != NULL)
379 {
380 cnx_connect(c);
381 }
382
383 return false;
384}
385
386bool vws_cnx_is_connected(vws_cnx* c)
387{
388 if (vws_socket_is_connected((vws_socket*)c) == false)
389 {
390 vws.error(VE_SOCKET, "vws_cnx_is_connected()");
391 return false;
392 }
393
394 return true;
395}
396
397bool socket_handshake(vws_socket* s)
398{
399 vws_cnx* c = (vws_cnx*)s;
400
401 // Send the WebSocket handshake request
402 const char* rt =
403 "GET %s HTTP/1.1\r\n"
404 "Host: %s\r\n"
405 "Cache-Control: no-cache\r\n"
406 "Origin: %s\r\n"
407 "Upgrade: websocket\r\n"
408 "Connection: Upgrade\r\n"
409 "Sec-WebSocket-Key: %s\r\n"
410 "Sec-WebSocket-Version: 13\r\n"
411 "\r\n";
412
413 char req[MAX_BUFFER_SIZE];
414 snprintf(req, sizeof(req), rt, c->url->path, c->url->host, c->url->href, c->key);
415
416 ssize_t n;
417 size_t total = 0;
418 size_t size = strlen(req);
419
420 while (true)
421 {
422 n = vws_socket_write(s, (ucstr)req, size);
423
424 if (vws_cnx_is_connected(c) == false)
425 {
426 return false;
427 }
428
429 if (n > 0)
430 {
431 total += n;
432
433 if (total == size)
434 {
435 break;
436 }
437 }
438
439 if (n == 0)
440 {
441 if (vws.e.code == VE_TIMEOUT)
442 {
443 return false;
444 }
445 }
446
447 if (n < 0)
448 {
449 return false;
450 }
451 }
452
453 // Create HTTP response message
454 vws_http_msg* http = vws_http_msg_new(HTTP_RESPONSE);
455
456 // Read until full HTTP response is received
457 while (true)
458 {
459 n = vws_socket_read(s);
460
461 if (vws_cnx_is_connected(c) == false)
462 {
463 // Clear the socket buffer of anything that did arrive,
464 // otherwise it will possibly be in inconsistent state.
465 vws_buffer_clear(c->base.buffer);
466
467 // Free HTTP response
468 vws_http_msg_free(http);
469
470 return false;
471 }
472
473 // If there was a timeout
474 if (n == 0)
475 {
476 // Fail
477 if (vws.e.code == VE_TIMEOUT)
478 {
479 // Clear the socket buffer of anything that did arrive,
480 // otherwise it will possibly be in inconsistent state.
481 vws_buffer_clear(c->base.buffer);
482
483 // Free HTTP response
484 vws_http_msg_free(http);
485
486 return false;
487 }
488 }
489
490 if (n < 0)
491 {
492 // Clear the socket buffer of anything that did arrive,
493 // otherwise it will possibly be in inconsistent state.
494 vws_buffer_clear(c->base.buffer);
495
496 // Free HTTP response
497 vws_http_msg_free(http);
498
499 return false;
500 }
501
502 if (s->buffer->size > 0)
503 {
504 cstr data = (cstr)s->buffer->data;
505 size_t size = s->buffer->size;
506 ssize_t n = vws_http_msg_parse(http, data, size);
507
508 if (http->headers_complete == true)
509 {
510 // Drain HTTP request data from socket buffer
511 vws_buffer_drain(c->base.buffer, n);
512
513 break;
514 }
515 }
516 }
517
518 struct sc_map_str* headers = &http->headers;
519 cstr accept_key = vws_map_get(headers, "sec-websocket-accept");
520
521 if (accept_key == NULL)
522 {
523 vws.error(VE_SYS, "connect failed: no accept key returned");
524
525 return false;
526 }
527
528 if (verify_handshake(c->key, accept_key) == false)
529 {
530 vws.error(VE_RT, "Handshake verification failed");
531 vws.free(accept_key);
532 return false;
533 }
534
535 vws_http_msg_free(http);
536
537 return true;
538}
539
540void vws_disconnect(vws_cnx* c)
541{
542 vws_socket* s = (vws_socket*)c;
543
544 if (vws_cnx_is_connected(c) == false)
545 {
546 return;
547 }
548
549 // If disconnect callback is registered
550 if (c->disconnect != NULL)
551 {
552 // Call it
553 c->disconnect(c);
554 }
555
556 c->flags = CNX_CLOSED;
557
558 vws_buffer* buffer = vws_generate_close_frame();
559
560 for (size_t i = 0; i < buffer->size;)
561 {
562 int n = vws_socket_write(s, buffer->data + i, buffer->size - i);
563
564 if (n < 0)
565 {
566 break;
567 }
568
569 i += n;
570 }
571
572 vws_buffer_free(buffer);
573
574 vws_socket_disconnect(s);
575}
576
577//------------------------------------------------------------------------------
578//> Messaging API
579//------------------------------------------------------------------------------
580
581ssize_t vws_frame_send_text(vws_cnx* c, cstr data)
582{
583 return vws_frame_send_data(c, (ucstr)data, strlen(data), 0x1);
584}
585
586ssize_t vws_frame_send_binary(vws_cnx* c, ucstr data, size_t size)
587{
588 return vws_frame_send_data(c, data, size, 0x2);
589}
590
591ssize_t vws_frame_send_data(vws_cnx* c, ucstr data, size_t size, int oc)
592{
593 return vws_frame_send(c, vws_frame_new(data, size, oc));
594}
595
596ssize_t vws_msg_send_text(vws_cnx* c, cstr data)
597{
598 return vws_frame_send_data(c, (ucstr)data, strlen(data), 0x1);
599}
600
601ssize_t vws_msg_send_binary(vws_cnx* c, ucstr data, size_t size)
602{
603 return vws_frame_send_data(c, data, size, 0x2);
604}
605
606ssize_t vws_msg_send_data(vws_cnx* c, ucstr data, size_t size, int oc)
607{
608 return vws_frame_send(c, vws_frame_new(data, size, oc));
609}
610
611ssize_t vws_frame_send(vws_cnx* c, vws_frame* frame)
612{
613 if (vws_cnx_is_connected(c) == false)
614 {
615 return -1;
616 }
617
618 vws_buffer* binary = vws_serialize(frame);
619
620 if (vws.tracelevel >= VT_PROTOCOL)
621 {
622 vws_trace_lock();
623 printf("\n\n");
624 printf("+----------------------------------------------------+\n");
625 printf("| Frame Sent |\n");
626 printf("+----------------------------------------------------+\n");
627
628 vws_dump_websocket_frame(binary->data, binary->size);
629 printf("------------------------------------------------------\n");
630 vws_trace_unlock();
631 }
632
633 ssize_t n = 0;
634
635 if (binary->data != NULL)
636 {
637 n = vws_socket_write((vws_socket*)c, binary->data, binary->size);
638 vws_buffer_free(binary);
639
640 if (vws_cnx_is_connected(c) == false)
641 {
642 return -1;
643 }
644 }
645
646 vws.success();
647
648 return n;
649}
650
651//------------------------------------------------------------------------------
652//> Message API
653//------------------------------------------------------------------------------
654
655vws_msg* vws_msg_new()
656{
657 vws_msg* m = vws.malloc(sizeof(vws_msg));
658 m->opcode = 0;
659 m->data = vws_buffer_new();
660
661 return m;
662}
663
664void vws_msg_free(vws_msg* m)
665{
666 if (m != NULL)
667 {
668 vws_buffer_free(m->data);
669 vws.free(m);
670 }
671}
672
673vws_msg* vws_msg_recv(vws_cnx* c)
674{
675 // Default success unless error
676 vws.success();
677
678 if (vws_cnx_is_connected(c) == false)
679 {
680 return NULL;
681 }
682
683 while (true)
684 {
685 vws_msg* msg = vws_msg_pop(c);
686
687 if (msg != NULL)
688 {
689 return msg;
690 }
691
692 if (socket_wait_for_frame(c) <= 0)
693 {
694 break;
695 }
696 }
697
698 return NULL;
699}
700
701//------------------------------------------------------------------------------
702//> Frame API
703//------------------------------------------------------------------------------
704
705vws_frame* vws_frame_new(ucstr data, size_t s, unsigned char oc)
706{
707 vws_frame* f = vws.malloc(sizeof(vws_frame));
708
709 // We must make our own copy of the data for deterministic memory management
710
711 f->fin = 1;
712 f->opcode = oc;
713 f->mask = 1;
714 f->offset = 0;
715 f->size = s;
716 f->data = NULL;
717
718 if (f->size > 0)
719 {
720 f->data = vws.malloc(f->size);
721 memcpy(f->data, data, f->size);
722 }
723
724 return f;
725}
726
727void vws_frame_free(vws_frame* f)
728{
729 if (f != NULL)
730 {
731 if (f->data != NULL)
732 {
733 vws.free(f->data);
734 f->data = NULL;
735 }
736
737 f->size = 0;
738
739 vws.free(f);
740 }
741}
742
743vws_frame* vws_frame_recv(vws_cnx* c)
744{
745 // Default success unless error
746 vws.success();
747
748 if (vws_cnx_is_connected(c) == false)
749 {
750 return NULL;
751 }
752
753 while (true)
754 {
755 if (sc_queue_size(&c->queue) > 0)
756 {
757 return sc_queue_del_last(&c->queue);
758 }
759
760 if (socket_wait_for_frame(c) <= 0)
761 {
762 break;
763 }
764 }
765
766 return NULL;
767}
768
769vws_buffer* vws_serialize(vws_frame* f)
770{
771 if (f == NULL)
772 {
773 vws.error(VE_RT, "empty frame");
774
775 return NULL;
776 }
777
778 //> Section 1: Size calculation
779
780 // Calculate the frame size
781 uint64_t payload_length = f->size;
782
783 // Set the mask bit and payload length. Maximum frame size with extended
784 // payload length and masking key
785 unsigned char header[14];
786
787 // Minimum frame size
788 size_t header_size = 2;
789
790 // Set the FIN bit and opcode
791 header[0] = f->fin << 7 | f->opcode;
792
793 if (payload_length <= 125)
794 {
795 header[1] = payload_length;
796 }
797 else if (payload_length <= 65535)
798 {
799 header[1] = 126;
800 header[2] = (payload_length >> 8) & 0xFF;
801 header[3] = payload_length & 0xFF;
802
803 // Additional bytes for payload length
804 header_size += 2;
805 }
806 else
807 {
808 header[1] = 127;
809 header[2] = (payload_length >> 56) & 0xFF;
810 header[3] = (payload_length >> 48) & 0xFF;
811 header[4] = (payload_length >> 40) & 0xFF;
812 header[5] = (payload_length >> 32) & 0xFF;
813 header[6] = (payload_length >> 24) & 0xFF;
814 header[7] = (payload_length >> 16) & 0xFF;
815 header[8] = (payload_length >> 8) & 0xFF;
816 header[9] = payload_length & 0xFF;
817
818 // Additional bytes for payload length
819 header_size += 8;
820 }
821
822 //> Section 2: Frame allocation
823
824 size_t frame_size = header_size + payload_length;
825
826 if (f->mask)
827 {
828 // Set the masking bit
829 header[1] |= 0x80;
830
831 // Additional bytes for masking key
832 frame_size += 4;
833 }
834
835 // Allocate memory for the frame
836 unsigned char* frame_data = (unsigned char*)vws.malloc(frame_size);
837
838 // Copy the header to the frame
839 memcpy(frame_data, header, header_size);
840
841 //> Section 3: Masking
842
843 if (f->mask)
844 {
845 // Generate a random masking key
846
847 unsigned char masking_key[4];
848
849 if (RAND_bytes(masking_key, sizeof(masking_key)) != 1)
850 {
851 vws.error(VE_RT, "RAND_bytes() failed");
852 vws.free(frame_data);
853 vws_frame_free(f);
854
855 return NULL;
856 }
857
858 // Copy the masking key to the frame
859 memcpy(frame_data + header_size, masking_key, 4);
860
861 // Apply masking to the payload data
862 size_t payload_start = header_size + 4;
863 for (size_t i = 0; i < payload_length; i++)
864 {
865 frame_data[payload_start + i] = f->data[i] ^ masking_key[i % 4];
866 }
867 }
868 else
869 {
870 // Copy the payload data without masking
871 memcpy(frame_data + header_size, f->data, payload_length);
872 }
873
874 //> Section 4: Finalizing
875
876 // Free the frame
877 vws_frame_free(f);
878
879 // Create the vws_buffer to hold the frame data
880 vws_buffer* buffer = vws_buffer_new();
881
882 // Have buffer take ownership of data
883 buffer->data = frame_data;
884 buffer->size = frame_size;
885
886 vws.success();
887
888 return buffer;
889}
890
891fs_t vws_deserialize(ucstr data, size_t size, vws_frame* f, size_t* consumed)
892{
893 // Check if the data contains the minimum required frame header bytes
894 if (size < 2)
895 {
896 return FRAME_INCOMPLETE;
897 }
898
899 // Read the first byte (FIN bit and opcode)
900 f->fin = (data[0] >> 7) & 0x01;
901 f->opcode = data[0] & 0x0F;
902
903 // Read the second byte (mask bit and payload length)
904 f->mask = (data[1] >> 7) & 0x01;
905 f->size = data[1] & 0x7F;
906
907 // Check if the payload length requires additional bytes
908 size_t size_bytes = 0;
909 if (f->size == 126)
910 {
911 size_bytes = 2;
912 }
913 else if (f->size == 127)
914 {
915 size_bytes = 8;
916 }
917
918 // Check if the data contains complete frame header and payload
919 size_t required_bytes = 2 + size_bytes;
920
921 if (size < required_bytes)
922 {
923 return FRAME_INCOMPLETE;
924 }
925
926 // Read the payload
927 if (size_bytes > 0)
928 {
929 f->size = 0;
930 for (size_t i = 0; i < size_bytes; i++)
931 {
932 f->size = (f->size << 8) | data[2 + i];
933 }
934 }
935
936 // Check if the frame has masking key and payload data
937 if (f->mask)
938 {
939 // Check if the data contains the masking key and payload data
940 required_bytes += 4 + f->size;
941
942 if (size < required_bytes)
943 {
944 return FRAME_INCOMPLETE;
945 }
946
947 // Store the payload offset
948 f->offset = 2 + size_bytes + 4;
949
950 // Allocate the frame data
951 f->data = vws.malloc(f->size);
952
953 // Create a temp variable for the masking key
954 unsigned char mask[4];
955 memcpy(mask, data + 2 + size_bytes, 4);
956
957 // Read the payload data and apply the masking
958 for (size_t i = 0; i < f->size; i++)
959 {
960 f->data[i] = data[f->offset + i] ^ mask[i % 4];
961 }
962 }
963 else
964 {
965 // Check if the data contains the payload data
966
967 required_bytes += f->size;
968
969 if (size < required_bytes)
970 {
971 return FRAME_INCOMPLETE;
972 }
973
974 // Store the payload offset
975 f->offset = 2 + size_bytes;
976
977 // Allocate the frame data
978 f->data = vws.malloc(f->size);
979
980 // Copy the payload data
981 memcpy(f->data, data + f->offset, f->size);
982 }
983
984 // Update the bytes consumed
985 *consumed = required_bytes;
986
987 return FRAME_COMPLETE;
988}
989
990//------------------------------------------------------------------------------
991// Utility functions
992//------------------------------------------------------------------------------
993
994void process_frame(vws_cnx* c, vws_frame* f)
995{
996 switch (f->opcode)
997 {
998 case CLOSE_FRAME:
999 {
1000 // Set closing state
1001 vws_set_flag(&c->flags, CNX_CLOSING);
1002
1003 // Build the response frame
1004 vws_buffer* buffer = vws_generate_close_frame();
1005
1006 // Send the response frame
1007 vws_socket_write((vws_socket*)c, buffer->data, buffer->size);
1008
1009 // Clean up
1010 vws_buffer_free(buffer);
1011 vws_frame_free(f);
1012
1013 break;
1014 }
1015
1016 case TEXT_FRAME:
1017 case BINARY_FRAME:
1018 case CONTINUATION_FRAME:
1019 {
1020 // Add to queue
1021 sc_queue_add_first(&c->queue, f);
1022
1023 break;
1024 }
1025
1026 case PING_FRAME:
1027 {
1028 // Generate the PONG response
1029 vws_buffer* buffer = vws_generate_pong_frame(f->data, f->size);
1030
1031 // Send the PONG response
1032 vws_socket_write((vws_socket*)c, buffer->data, buffer->size);
1033
1034 // Clean up
1035 vws_buffer_free(buffer);
1036 vws_frame_free(f);
1037
1038 break;
1039 }
1040
1041 case PONG_FRAME:
1042 {
1043 // No need to send a response
1044
1045 vws_frame_free(f);
1046
1047 break;
1048 }
1049
1050 default:
1051 {
1052 // Invalid frame type
1053 vws_frame_free(f);
1054 }
1055 }
1056
1057 vws.success();
1058}
1059
/**
 * @brief Produces a fresh Sec-WebSocket-Key: 16 random bytes, base64-encoded.
 *
 * @return A heap-allocated key (release with vws.free()), or NULL if
 *         RAND_bytes() or the encoder failed.
 */
char* generate_websocket_key()
{
    unsigned char nonce[16];

    bool have_random = (RAND_bytes(nonce, sizeof nonce) == 1);

    // The encoder's result (possibly NULL) is handed back unchanged.
    return have_random ? vws_base64_encode(nonce, sizeof nonce) : NULL;
}
1079
1080cstr vws_accept_key(cstr key)
1081{
1082 // Concatenate the key and WebSocket GUID
1083 const char* websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1084 size_t key_length = strlen(key);
1085 size_t guid_length = strlen(websocket_guid);
1086 size_t input_length = key_length + guid_length;
1087 char* input = (char*)vws.malloc(input_length + 1);
1088
1089 strncpy(input, key, key_length);
1090 strncpy(input + key_length, websocket_guid, guid_length);
1091 input[input_length] = '\0';
1092
1093 // Compute the SHA-1 hash of the concatenated value
1094 unsigned char hash[SHA_DIGEST_LENGTH];
1095 SHA1((const unsigned char*)input, input_length, hash);
1096
1097 // Base64-encode the hash
1098 char* encoded_hash = vws_base64_encode(hash, sizeof(hash));
1099
1100 vws.free(input);
1101
1102 return encoded_hash;
1103}
1104
1105int verify_handshake(const char* key, const char* response)
1106{
1107 char* hash = vws_accept_key(key);
1108 int result = strcmp(hash, response);
1109 vws.free(hash);
1110
1111 return result == 0;
1112}
1113
1114ssize_t socket_wait_for_frame(vws_cnx* c)
1115{
1116 // Default success unless error
1117 vws.success();
1118
1119 if (vws_cnx_is_connected(c) == false)
1120 {
1121 return -1;
1122 }
1123
1124 ssize_t n;
1125 unsigned char buf[1024];
1126
1127 while (true)
1128 {
1129 memset(buf, 0, 1024);
1130 n = vws_socket_read((vws_socket*)c);
1131
1132 if (n <= 0)
1133 {
1134 // Error already set
1135 return n;
1136 }
1137
1138 if (vws_cnx_ingress(c) > 0)
1139 {
1140 break;
1141 }
1142 }
1143
1144 return n;
1145}
1146
1147ssize_t vws_cnx_ingress(vws_cnx* c)
1148{
1149 size_t total_consumed = 0;
1150
1151 // Process as many frames as possible
1152 while (true)
1153 {
1154 // If there is no more data in socket buffer
1155 if (c->base.buffer->size == 0)
1156 {
1157 break;
1158 }
1159
1160 // Attempt to parse complete frame
1161 size_t consumed = 0;
1162 vws_buffer* b = c->base.buffer;
1163 vws_frame* frame = vws_frame_new(NULL, 0, TEXT_FRAME);
1164
1165 if (vws.tracelevel >= VT_PROTOCOL)
1166 {
1167 vws_trace_lock();
1168 printf("\n+----------------------------------------------------+\n");
1169 printf("| Frame Received |\n");
1170 printf("+----------------------------------------------------+\n");
1171 vws_dump_websocket_frame(b->data, b->size);
1172 printf("------------------------------------------------------\n");
1173 vws_trace_unlock();
1174 }
1175
1176 fs_t rc = vws_deserialize(b->data, b->size, frame, &consumed);
1177
1178 if (rc == FRAME_ERROR)
1179 {
1180 vws.error(VE_WARN, "FRAME_ERROR");
1181 vws_frame_free(frame);
1182
1183 return 0;
1184 }
1185
1186 if (rc == FRAME_INCOMPLETE)
1187 {
1188 // No complete frame in socket buffer
1189 vws_frame_free(frame);
1190
1191 return 0;
1192 }
1193
1194 // Update
1195 total_consumed += consumed;
1196
1197 // We have a frame. Process it.
1198 c->process(c, frame);
1199
1200 // Drain the consumed frame data from buffer
1201 vws_buffer_drain(c->base.buffer, consumed);
1202 }
1203
1204 vws.success();
1205
1206 return total_consumed;
1207}
1208
1209vws_buffer* vws_generate_close_frame()
1210{
1211 size_t size = sizeof(int16_t);
1212 int16_t* data = vws.malloc(size);
1213
1214 // Convert to network byte order before assignement
1215 *data = htons(WS_CLOSE_NORMAL);
1216 vws_frame* f = vws_frame_new((ucstr)data, size, CLOSE_FRAME);
1217
1218 vws.free(data);
1219
1220 return vws_serialize(f);
1221}
1222
1223vws_buffer* vws_generate_pong_frame(ucstr ping_data, size_t s)
1224{
1225 // We create a new frame with the same data as the received ping frame
1226 vws_frame* f = vws_frame_new(ping_data, s, PONG_FRAME);
1227
1228 // Serialize the frame and return it
1229 return vws_serialize(f);
1230}
1231
1232vws_msg* vws_msg_pop(vws_cnx* c)
1233{
1234 if (has_complete_message(c) == false)
1235 {
1236 return NULL;
1237 }
1238
1239 // Create new message
1240 vws_msg* m = vws_msg_new();
1241
1242 // Set to sentinel value to detect first frame
1243 m->opcode = 100;
1244
1245 do
1246 {
1247 vws_frame* f = sc_queue_del_last(&c->queue);
1248
1249 // If this is first frame, opcode is sentinel value. We take the opcode
1250 // from the first frame only
1251 if (m->opcode == 100)
1252 {
1253 m->opcode = f->opcode;
1254 }
1255
1256 // Copy frame data into message buffer
1257 vws_buffer_append(m->data, f->data, f->size);
1258
1259 // Is this the completion frame?
1260 bool complete = (f->fin == 1);
1261
1262 // Free frame
1263 vws_frame_free(f);
1264
1265 if (complete)
1266 {
1267 break;
1268 }
1269 }
1270 while (true);
1271
1272 return m;
1273}
1274
1275bool has_complete_message(vws_cnx* c)
1276{
1277 vws_frame* f;
1278 sc_queue_foreach (&c->queue, f)
1279 {
1280 if (f->fin == 1)
1281 {
1282 return true;
1283 }
1284 }
1285
1286 return false;
1287}
1288
1289void dump_websocket_header(const ws_header* header)
1290{
1291 printf(" fin: %u\n", header->fin);
1292 printf(" opcode: %u\n", header->opcode);
1293 printf(" mask: %u (0x%08x)\n", header->mask, header->masking_key);
1294 printf(" payload: %lu bytes\n", header->payload_len);
1295 printf("\n");
1296}
1297
1298void vws_dump_websocket_frame(const uint8_t* frame, size_t size)
1299{
1300 if (size < 2)
1301 {
1302 printf("Invalid WebSocket frame\n");
1303 return;
1304 }
1305
1306 ws_header header;
1307 size_t header_size = 2;
1308 header.fin = (frame[0] & 0x80) >> 7;
1309 header.opcode = frame[0] & 0x0F;
1310 header.mask = (frame[1] & 0x80) >> 7;
1311 header.payload_len = frame[1] & 0x7F;
1312
1313 if (header.payload_len == 126)
1314 {
1315 if (size < 4)
1316 {
1317 printf("Invalid WebSocket frame\n");
1318 return;
1319 }
1320
1321 header_size += 2;
1322 header.payload_len = ((uint64_t)frame[2] << 8) | frame[3];
1323 }
1324 else if (header.payload_len == 127)
1325 {
1326 if (size < 10)
1327 {
1328 printf("Invalid WebSocket frame\n");
1329 return;
1330 }
1331
1332 header_size += 8;
1333 header.payload_len =
1334 ((uint64_t)frame[2] << 56) |
1335 ((uint64_t)frame[3] << 48) |
1336 ((uint64_t)frame[4] << 40) |
1337 ((uint64_t)frame[5] << 32) |
1338 ((uint64_t)frame[6] << 24) |
1339 ((uint64_t)frame[7] << 16) |
1340 ((uint64_t)frame[8] << 8) |
1341 frame[9];
1342 }
1343
1344 if (header.mask)
1345 {
1346 if (size < header_size + 4)
1347 {
1348 printf("Invalid WebSocket frame\n");
1349 return;
1350 }
1351 header.masking_key =
1352 ((uint32_t)frame[header_size] << 24) |
1353 ((uint32_t)frame[header_size + 1] << 16) |
1354 ((uint32_t)frame[header_size + 2] << 8) |
1355 frame[header_size + 3];
1356 header_size += 4;
1357 }
1358 else
1359 {
1360 header.masking_key = 0;
1361 }
1362
1363 printf(" header: %zu bytes\n", header_size);
1364 dump_websocket_header(&header);
1365
1366 if (size > header_size)
1367 {
1368 for (size_t i = header_size; i < size; ++i)
1369 {
1370 printf("%02x ", frame[i]);
1371 if ((i - header_size + 1) % 16 == 0)
1372 {
1373 printf("\n");
1374 }
1375 }
1376
1377 printf("\n");
1378 }
1379}
1380