1 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
2 | #include <netinet/in.h> |
3 | #include <sys/socket.h> |
4 | #endif |
5 | |
6 | #if defined(__windows__) |
7 | #define _WIN32_WINNT 0x0601 // Windows 7 or later |
8 | #include <winsock2.h> |
9 | #include <ws2tcpip.h> |
10 | #endif |
11 | |
12 | #include <openssl/rand.h> |
13 | |
14 | #include "http_message.h" |
15 | #include "websocket.h" |
16 | #include "url.h" |
17 | |
18 | #define MAX_BUFFER_SIZE 1024 |
19 | |
/**
 * @brief Bit flags describing the state of a WebSocket connection.
 *
 * Each non-zero value is a distinct power of two so several can be combined
 * in one flags word (see the vws_set_flag() calls). Bit 0 is not used.
 */
typedef enum
{
    CNX_CLOSED    = 0x00, /**< No open connection with the peer. */
    CNX_CONNECTED = 0x02, /**< Connection established and open. */
    CNX_CLOSING   = 0x04, /**< Connection is being closed. */
    CNX_SSL_INIT  = 0x08, /**< Initial SSL/TLS handshake in progress. */
    CNX_SERVER    = 0x10  /**< Connection operates in server mode. */
} cnx_flags_t;
39 | |
/**
 * @brief Status codes carried in the payload of WebSocket CLOSE frames
 *        (RFC 6455 section 7.4 and the IANA close-code registry).
 */
typedef enum
{
    WS_CLOSE_NORMAL            = 1000, /**< Normal closure; the purpose of the connection was fulfilled. */
    WS_CLOSE_GOING_AWAY        = 1001, /**< Going away, e.g. server shutting down or page navigated away. */
    WS_CLOSE_PROTOCOL_ERROR    = 1002, /**< Terminating because of a protocol error. */
    WS_CLOSE_UNSUPPORTED       = 1003, /**< Received a type of data the endpoint cannot accept. */
    WS_CLOSE_RESERVED          = 1004, /**< Reserved; meaning may be defined in the future. */
    WS_CLOSE_NO_STATUS         = 1005, /**< Reserved; connection closed without a status code. */
    WS_CLOSE_ABNORMAL          = 1006, /**< Reserved; connection closed abnormally, no status code. */
    WS_CLOSE_INVALID_PAYLOAD   = 1007, /**< Message contained inconsistent data (e.g. bad payload). */
    WS_CLOSE_POLICY_VIOLATION  = 1008, /**< Message violated the endpoint's policy. */
    WS_CLOSE_TOO_BIG           = 1009, /**< Data frame too large to process. */
    WS_CLOSE_MISSING_EXTENSION = 1010, /**< Client expected the server to negotiate an extension it did not. */
    WS_CLOSE_INTERNAL_ERROR    = 1011, /**< Server hit an unexpected condition. */
    WS_CLOSE_SERVICE_RESTART   = 1012, /**< Server is restarting. */
    WS_CLOSE_TRY_AGAIN_LATER   = 1013, /**< Temporary condition, e.g. server overloaded. */
    WS_CLOSE_BAD_GATEWAY       = 1014, /**< Gateway/proxy got an invalid upstream response (cf. HTTP 502). */
    WS_CLOSE_TLS_HANDSHAKE     = 1015  /**< Reserved; TLS handshake failure. */
} ws_close_code_t;
111 | |
112 | //------------------------------------------------------------------------------ |
113 | // Internal functions |
114 | //------------------------------------------------------------------------------ |
115 | |
116 | /** |
117 | * @defgroup ConnectionFunctions |
118 | * |
119 | * @brief Functions that manage WebSocket connections |
120 | * |
121 | */ |
122 | |
123 | /** |
124 | * @brief Attempts connection based on previously parsed URL |
125 | * |
126 | * @param c The websocket connection. |
127 | * @return Returns true if the connection is successful, false otherwise. |
128 | */ |
129 | static bool cnx_connect(); |
130 | |
/**
 * @brief Generates a new, random WebSocket key for the handshake process.
 *
 * @return A heap-allocated, base64-encoded key (released with vws.free()), or
 *         NULL if random-byte generation or encoding fails.
 *
 * @ingroup ConnectionFunctions
 */
static char* generate_websocket_key(void);
139 | |
/**
 * @brief Extracts the WebSocket accept key from a server's handshake response.
 *
 * @param response The server's handshake response.
 * @return A pointer to the extracted WebSocket accept key.
 *
 * @ingroup ConnectionFunctions
 */
static char* extract_websocket_accept_key(const char* response);
149 | |
/**
 * @brief Verifies the server's handshake response using the client's original
 * key.
 *
 * Derives the expected Sec-WebSocket-Accept value from @p key (via
 * vws_accept_key()) and compares it with the value the server returned.
 *
 * @param key The client's original Sec-WebSocket-Key value.
 * @param response The Sec-WebSocket-Accept value returned by the server.
 * @return 1 if the handshake is verified, 0 otherwise. Callers treat the
 *         result as a boolean.
 *
 * @ingroup ConnectionFunctions
 */
static int verify_handshake(const char* key, const char* response);
161 | |
162 | |
163 | |
164 | |
165 | /** |
166 | * @defgroup SocketFunctions |
167 | * |
168 | * @brief Functions that manage sockets |
169 | * |
170 | */ |
171 | |
172 | /** |
173 | * @brief Performs WebSocket connection. This is integrated as callback into |
174 | * vws_connect(). |
175 | * |
176 | * @param s The WebSocket connection |
177 | * @return Returns true if handshake succeeded, false otherwise |
178 | * |
179 | * @ingroup ConnectionFunctions |
180 | */ |
181 | |
182 | static bool socket_handshake(vws_socket* s); |
183 | |
184 | /** |
185 | * @brief Waits for a complete frame to be available from a WebSocket |
186 | * connection. |
187 | * |
188 | * @param c The vws_cnx representing the WebSocket connection. |
189 | * @return The number of bytes read and processed, or an error code if an error |
190 | * occurred. |
191 | * |
192 | * @ingroup SocketFunctions |
193 | */ |
194 | static ssize_t socket_wait_for_frame(vws_cnx* c); |
195 | |
196 | /** |
197 | * @defgroup FrameFunctions |
198 | * |
199 | * @brief Functions that manage websocket frames |
200 | * |
201 | */ |
202 | |
203 | /** |
204 | * @brief Processes a single frame from a WebSocket connection. |
205 | * |
206 | * @param c The vws_cnx representing the WebSocket connection. |
207 | * @param frame The vws_frame to process. |
208 | * @return void |
209 | * |
210 | * @ingroup FrameFunctions |
211 | */ |
212 | static void process_frame(vws_cnx* c, vws_frame* frame); |
213 | |
214 | |
215 | |
216 | |
217 | /** |
218 | * @defgroup MessageFunctions |
219 | * |
220 | * @brief Functions that manage websocket messages |
221 | * |
222 | */ |
223 | |
224 | /** |
225 | * @brief Checks if a complete message is available in the connection's message |
226 | * queue. |
227 | * |
228 | * @param c The vws_cnx representing the WebSocket connection. |
229 | * @return True if a complete message is available, false otherwise. |
230 | * |
231 | * @ingroup MessageFunctions |
232 | */ |
233 | static bool has_complete_message(vws_cnx* c); |
234 | |
235 | |
236 | |
237 | |
238 | /** |
239 | * @defgroup TraceFunctions |
240 | * |
241 | * @brief Functions that provide tracing and debugging |
242 | * |
243 | */ |
244 | |
/**
 * @brief Structure representing a WebSocket header.
 *
 * Decoded view of a frame header's fields, consumed by
 * dump_websocket_header() for tracing.
 *
 * @ingroup TraceFunctions
 */
typedef struct
{
    uint8_t fin;          /**< Indicates the final frame of a message */
    uint8_t opcode;       /**< Identifies the frame type */
    uint8_t mask;         /**< Indicates if the frame payload is masked */
    uint64_t payload_len; /**< Length of the payload data */
    uint32_t masking_key; /**< Key used for payload data masking */
} ws_header;
258 | |
259 | /** |
260 | * @brief Dumps the contents of a WebSocket header for debugging purposes. |
261 | * |
262 | * @param header The WebSocket header to dump. |
263 | * |
264 | * @ingroup TraceFunctions |
265 | */ |
266 | static void dump_websocket_header(const ws_header* ); |
267 | |
268 | //------------------------------------------------------------------------------ |
269 | //> Connection API |
270 | //------------------------------------------------------------------------------ |
271 | |
272 | bool cnx_connect(vws_cnx* c) |
273 | { |
274 | if (c->url->host == NULL) |
275 | { |
276 | vws.error(VE_MEM, "Invalid or missing host" ); |
277 | return false; |
278 | } |
279 | |
280 | // Connect to the server |
281 | cstr default_port = strcmp(c->url->protocol, "wss" ) == 0 ? "443" : "80" ; |
282 | cstr port = c->url->port != NULL ? c->url->port : default_port; |
283 | |
284 | bool ssl = false; |
285 | if (strcmp(c->url->protocol, "wss" ) == 0) |
286 | { |
287 | ssl = true; |
288 | } |
289 | |
290 | return vws_socket_connect((vws_socket*)c, c->url->host, atoi(port), ssl); |
291 | } |
292 | |
293 | vws_cnx* vws_cnx_new() |
294 | { |
295 | vws_cnx* c = (vws_cnx*)vws.malloc(sizeof(vws_cnx)); |
296 | memset(c, 0, sizeof(vws_cnx)); |
297 | |
298 | // Call base constructor |
299 | vws_socket_ctor((vws_socket*)c); |
300 | |
301 | c->base.hs = socket_handshake; |
302 | c->flags = CNX_CLOSED; |
303 | c->url = NULL; |
304 | c->key = generate_websocket_key(); |
305 | c->process = process_frame; |
306 | c->disconnect = NULL; |
307 | c->data = NULL; |
308 | |
309 | sc_queue_init(&c->queue); |
310 | |
311 | return c; |
312 | } |
313 | |
314 | void vws_cnx_free(vws_cnx* c) |
315 | { |
316 | if (c == NULL) |
317 | { |
318 | return; |
319 | } |
320 | |
321 | vws_disconnect(c); |
322 | |
323 | // Free receive queue contents |
324 | vws_frame* f; |
325 | sc_queue_foreach (&c->queue, f) |
326 | { |
327 | vws_frame_free(f); |
328 | } |
329 | |
330 | // Free receive queue |
331 | sc_queue_term(&c->queue); |
332 | |
333 | // Free URL |
334 | if (c->url != NULL) |
335 | { |
336 | url_free((url_data_t*)c->url); |
337 | c->url = NULL; |
338 | } |
339 | |
340 | // Free websocket key |
341 | vws.free(c->key); |
342 | |
343 | // Call base constructor |
344 | vws_socket_dtor((vws_socket*)c); |
345 | } |
346 | |
347 | void vws_cnx_set_server_mode(vws_cnx* c) |
348 | { |
349 | vws_set_flag(&c->flags, CNX_SERVER); |
350 | } |
351 | |
352 | bool vws_connect(vws_cnx* c, cstr uri) |
353 | { |
354 | if (c == NULL) |
355 | { |
356 | // Return early if failed to create a connection. |
357 | vws.error(VE_RT, "Invalid connection pointer()" ); |
358 | return false; |
359 | } |
360 | |
361 | if (c->url != NULL) |
362 | { |
363 | url_free((url_data_t*)c->url); |
364 | } |
365 | |
366 | c->url = (vws_url_data*)url_parse(uri); |
367 | |
368 | return cnx_connect(c); |
369 | } |
370 | |
371 | bool vws_reconnect(vws_cnx* c) |
372 | { |
373 | if (vws_cnx_is_connected(c) == true) |
374 | { |
375 | return true; |
376 | } |
377 | |
378 | if (c->url != NULL) |
379 | { |
380 | cnx_connect(c); |
381 | } |
382 | |
383 | return false; |
384 | } |
385 | |
386 | bool vws_cnx_is_connected(vws_cnx* c) |
387 | { |
388 | if (vws_socket_is_connected((vws_socket*)c) == false) |
389 | { |
390 | vws.error(VE_SOCKET, "vws_cnx_is_connected()" ); |
391 | return false; |
392 | } |
393 | |
394 | return true; |
395 | } |
396 | |
397 | bool socket_handshake(vws_socket* s) |
398 | { |
399 | vws_cnx* c = (vws_cnx*)s; |
400 | |
401 | // Send the WebSocket handshake request |
402 | const char* rt = |
403 | "GET %s HTTP/1.1\r\n" |
404 | "Host: %s\r\n" |
405 | "Cache-Control: no-cache\r\n" |
406 | "Origin: %s\r\n" |
407 | "Upgrade: websocket\r\n" |
408 | "Connection: Upgrade\r\n" |
409 | "Sec-WebSocket-Key: %s\r\n" |
410 | "Sec-WebSocket-Version: 13\r\n" |
411 | "\r\n" ; |
412 | |
413 | char req[MAX_BUFFER_SIZE]; |
414 | snprintf(req, sizeof(req), rt, c->url->path, c->url->host, c->url->href, c->key); |
415 | |
416 | ssize_t n; |
417 | size_t total = 0; |
418 | size_t size = strlen(req); |
419 | |
420 | while (true) |
421 | { |
422 | n = vws_socket_write(s, (ucstr)req, size); |
423 | |
424 | if (vws_cnx_is_connected(c) == false) |
425 | { |
426 | return false; |
427 | } |
428 | |
429 | if (n > 0) |
430 | { |
431 | total += n; |
432 | |
433 | if (total == size) |
434 | { |
435 | break; |
436 | } |
437 | } |
438 | |
439 | if (n == 0) |
440 | { |
441 | if (vws.e.code == VE_TIMEOUT) |
442 | { |
443 | return false; |
444 | } |
445 | } |
446 | |
447 | if (n < 0) |
448 | { |
449 | return false; |
450 | } |
451 | } |
452 | |
453 | // Create HTTP response message |
454 | vws_http_msg* http = vws_http_msg_new(HTTP_RESPONSE); |
455 | |
456 | // Read until full HTTP response is received |
457 | while (true) |
458 | { |
459 | n = vws_socket_read(s); |
460 | |
461 | if (vws_cnx_is_connected(c) == false) |
462 | { |
463 | // Clear the socket buffer of anything that did arrive, |
464 | // otherwise it will possibly be in inconsistent state. |
465 | vws_buffer_clear(c->base.buffer); |
466 | |
467 | // Free HTTP response |
468 | vws_http_msg_free(http); |
469 | |
470 | return false; |
471 | } |
472 | |
473 | // If there was a timeout |
474 | if (n == 0) |
475 | { |
476 | // Fail |
477 | if (vws.e.code == VE_TIMEOUT) |
478 | { |
479 | // Clear the socket buffer of anything that did arrive, |
480 | // otherwise it will possibly be in inconsistent state. |
481 | vws_buffer_clear(c->base.buffer); |
482 | |
483 | // Free HTTP response |
484 | vws_http_msg_free(http); |
485 | |
486 | return false; |
487 | } |
488 | } |
489 | |
490 | if (n < 0) |
491 | { |
492 | // Clear the socket buffer of anything that did arrive, |
493 | // otherwise it will possibly be in inconsistent state. |
494 | vws_buffer_clear(c->base.buffer); |
495 | |
496 | // Free HTTP response |
497 | vws_http_msg_free(http); |
498 | |
499 | return false; |
500 | } |
501 | |
502 | if (s->buffer->size > 0) |
503 | { |
504 | cstr data = (cstr)s->buffer->data; |
505 | size_t size = s->buffer->size; |
506 | ssize_t n = vws_http_msg_parse(http, data, size); |
507 | |
508 | if (http->headers_complete == true) |
509 | { |
510 | // Drain HTTP request data from socket buffer |
511 | vws_buffer_drain(c->base.buffer, n); |
512 | |
513 | break; |
514 | } |
515 | } |
516 | } |
517 | |
518 | struct sc_map_str* = &http->headers; |
519 | cstr accept_key = vws_map_get(headers, "sec-websocket-accept" ); |
520 | |
521 | if (accept_key == NULL) |
522 | { |
523 | vws.error(VE_SYS, "connect failed: no accept key returned" ); |
524 | |
525 | return false; |
526 | } |
527 | |
528 | if (verify_handshake(c->key, accept_key) == false) |
529 | { |
530 | vws.error(VE_RT, "Handshake verification failed" ); |
531 | vws.free(accept_key); |
532 | return false; |
533 | } |
534 | |
535 | vws_http_msg_free(http); |
536 | |
537 | return true; |
538 | } |
539 | |
540 | void vws_disconnect(vws_cnx* c) |
541 | { |
542 | vws_socket* s = (vws_socket*)c; |
543 | |
544 | if (vws_cnx_is_connected(c) == false) |
545 | { |
546 | return; |
547 | } |
548 | |
549 | // If disconnect callback is registered |
550 | if (c->disconnect != NULL) |
551 | { |
552 | // Call it |
553 | c->disconnect(c); |
554 | } |
555 | |
556 | c->flags = CNX_CLOSED; |
557 | |
558 | vws_buffer* buffer = vws_generate_close_frame(); |
559 | |
560 | for (size_t i = 0; i < buffer->size;) |
561 | { |
562 | int n = vws_socket_write(s, buffer->data + i, buffer->size - i); |
563 | |
564 | if (n < 0) |
565 | { |
566 | break; |
567 | } |
568 | |
569 | i += n; |
570 | } |
571 | |
572 | vws_buffer_free(buffer); |
573 | |
574 | vws_socket_disconnect(s); |
575 | } |
576 | |
577 | //------------------------------------------------------------------------------ |
578 | //> Messaging API |
579 | //------------------------------------------------------------------------------ |
580 | |
581 | ssize_t vws_frame_send_text(vws_cnx* c, cstr data) |
582 | { |
583 | return vws_frame_send_data(c, (ucstr)data, strlen(data), 0x1); |
584 | } |
585 | |
586 | ssize_t vws_frame_send_binary(vws_cnx* c, ucstr data, size_t size) |
587 | { |
588 | return vws_frame_send_data(c, data, size, 0x2); |
589 | } |
590 | |
591 | ssize_t vws_frame_send_data(vws_cnx* c, ucstr data, size_t size, int oc) |
592 | { |
593 | return vws_frame_send(c, vws_frame_new(data, size, oc)); |
594 | } |
595 | |
596 | ssize_t vws_msg_send_text(vws_cnx* c, cstr data) |
597 | { |
598 | return vws_frame_send_data(c, (ucstr)data, strlen(data), 0x1); |
599 | } |
600 | |
601 | ssize_t vws_msg_send_binary(vws_cnx* c, ucstr data, size_t size) |
602 | { |
603 | return vws_frame_send_data(c, data, size, 0x2); |
604 | } |
605 | |
606 | ssize_t vws_msg_send_data(vws_cnx* c, ucstr data, size_t size, int oc) |
607 | { |
608 | return vws_frame_send(c, vws_frame_new(data, size, oc)); |
609 | } |
610 | |
611 | ssize_t vws_frame_send(vws_cnx* c, vws_frame* frame) |
612 | { |
613 | if (vws_cnx_is_connected(c) == false) |
614 | { |
615 | return -1; |
616 | } |
617 | |
618 | vws_buffer* binary = vws_serialize(frame); |
619 | |
620 | if (vws.tracelevel >= VT_PROTOCOL) |
621 | { |
622 | vws_trace_lock(); |
623 | printf("\n\n" ); |
624 | printf("+----------------------------------------------------+\n" ); |
625 | printf("| Frame Sent |\n" ); |
626 | printf("+----------------------------------------------------+\n" ); |
627 | |
628 | vws_dump_websocket_frame(binary->data, binary->size); |
629 | printf("------------------------------------------------------\n" ); |
630 | vws_trace_unlock(); |
631 | } |
632 | |
633 | ssize_t n = 0; |
634 | |
635 | if (binary->data != NULL) |
636 | { |
637 | n = vws_socket_write((vws_socket*)c, binary->data, binary->size); |
638 | vws_buffer_free(binary); |
639 | |
640 | if (vws_cnx_is_connected(c) == false) |
641 | { |
642 | return -1; |
643 | } |
644 | } |
645 | |
646 | vws.success(); |
647 | |
648 | return n; |
649 | } |
650 | |
651 | //------------------------------------------------------------------------------ |
652 | //> Message API |
653 | //------------------------------------------------------------------------------ |
654 | |
655 | vws_msg* vws_msg_new() |
656 | { |
657 | vws_msg* m = vws.malloc(sizeof(vws_msg)); |
658 | m->opcode = 0; |
659 | m->data = vws_buffer_new(); |
660 | |
661 | return m; |
662 | } |
663 | |
664 | void vws_msg_free(vws_msg* m) |
665 | { |
666 | if (m != NULL) |
667 | { |
668 | vws_buffer_free(m->data); |
669 | vws.free(m); |
670 | } |
671 | } |
672 | |
673 | vws_msg* vws_msg_recv(vws_cnx* c) |
674 | { |
675 | // Default success unless error |
676 | vws.success(); |
677 | |
678 | if (vws_cnx_is_connected(c) == false) |
679 | { |
680 | return NULL; |
681 | } |
682 | |
683 | while (true) |
684 | { |
685 | vws_msg* msg = vws_msg_pop(c); |
686 | |
687 | if (msg != NULL) |
688 | { |
689 | return msg; |
690 | } |
691 | |
692 | if (socket_wait_for_frame(c) <= 0) |
693 | { |
694 | break; |
695 | } |
696 | } |
697 | |
698 | return NULL; |
699 | } |
700 | |
701 | //------------------------------------------------------------------------------ |
702 | //> Frame API |
703 | //------------------------------------------------------------------------------ |
704 | |
705 | vws_frame* vws_frame_new(ucstr data, size_t s, unsigned char oc) |
706 | { |
707 | vws_frame* f = vws.malloc(sizeof(vws_frame)); |
708 | |
709 | // We must make our own copy of the data for deterministic memory management |
710 | |
711 | f->fin = 1; |
712 | f->opcode = oc; |
713 | f->mask = 1; |
714 | f->offset = 0; |
715 | f->size = s; |
716 | f->data = NULL; |
717 | |
718 | if (f->size > 0) |
719 | { |
720 | f->data = vws.malloc(f->size); |
721 | memcpy(f->data, data, f->size); |
722 | } |
723 | |
724 | return f; |
725 | } |
726 | |
727 | void vws_frame_free(vws_frame* f) |
728 | { |
729 | if (f != NULL) |
730 | { |
731 | if (f->data != NULL) |
732 | { |
733 | vws.free(f->data); |
734 | f->data = NULL; |
735 | } |
736 | |
737 | f->size = 0; |
738 | |
739 | vws.free(f); |
740 | } |
741 | } |
742 | |
743 | vws_frame* vws_frame_recv(vws_cnx* c) |
744 | { |
745 | // Default success unless error |
746 | vws.success(); |
747 | |
748 | if (vws_cnx_is_connected(c) == false) |
749 | { |
750 | return NULL; |
751 | } |
752 | |
753 | while (true) |
754 | { |
755 | if (sc_queue_size(&c->queue) > 0) |
756 | { |
757 | return sc_queue_del_last(&c->queue); |
758 | } |
759 | |
760 | if (socket_wait_for_frame(c) <= 0) |
761 | { |
762 | break; |
763 | } |
764 | } |
765 | |
766 | return NULL; |
767 | } |
768 | |
769 | vws_buffer* vws_serialize(vws_frame* f) |
770 | { |
771 | if (f == NULL) |
772 | { |
773 | vws.error(VE_RT, "empty frame" ); |
774 | |
775 | return NULL; |
776 | } |
777 | |
778 | //> Section 1: Size calculation |
779 | |
780 | // Calculate the frame size |
781 | uint64_t payload_length = f->size; |
782 | |
783 | // Set the mask bit and payload length. Maximum frame size with extended |
784 | // payload length and masking key |
785 | unsigned char [14]; |
786 | |
787 | // Minimum frame size |
788 | size_t = 2; |
789 | |
790 | // Set the FIN bit and opcode |
791 | header[0] = f->fin << 7 | f->opcode; |
792 | |
793 | if (payload_length <= 125) |
794 | { |
795 | header[1] = payload_length; |
796 | } |
797 | else if (payload_length <= 65535) |
798 | { |
799 | header[1] = 126; |
800 | header[2] = (payload_length >> 8) & 0xFF; |
801 | header[3] = payload_length & 0xFF; |
802 | |
803 | // Additional bytes for payload length |
804 | header_size += 2; |
805 | } |
806 | else |
807 | { |
808 | header[1] = 127; |
809 | header[2] = (payload_length >> 56) & 0xFF; |
810 | header[3] = (payload_length >> 48) & 0xFF; |
811 | header[4] = (payload_length >> 40) & 0xFF; |
812 | header[5] = (payload_length >> 32) & 0xFF; |
813 | header[6] = (payload_length >> 24) & 0xFF; |
814 | header[7] = (payload_length >> 16) & 0xFF; |
815 | header[8] = (payload_length >> 8) & 0xFF; |
816 | header[9] = payload_length & 0xFF; |
817 | |
818 | // Additional bytes for payload length |
819 | header_size += 8; |
820 | } |
821 | |
822 | //> Section 2: Frame allocation |
823 | |
824 | size_t frame_size = header_size + payload_length; |
825 | |
826 | if (f->mask) |
827 | { |
828 | // Set the masking bit |
829 | header[1] |= 0x80; |
830 | |
831 | // Additional bytes for masking key |
832 | frame_size += 4; |
833 | } |
834 | |
835 | // Allocate memory for the frame |
836 | unsigned char* frame_data = (unsigned char*)vws.malloc(frame_size); |
837 | |
838 | // Copy the header to the frame |
839 | memcpy(frame_data, header, header_size); |
840 | |
841 | //> Section 3: Masking |
842 | |
843 | if (f->mask) |
844 | { |
845 | // Generate a random masking key |
846 | |
847 | unsigned char masking_key[4]; |
848 | |
849 | if (RAND_bytes(masking_key, sizeof(masking_key)) != 1) |
850 | { |
851 | vws.error(VE_RT, "RAND_bytes() failed" ); |
852 | vws.free(frame_data); |
853 | vws_frame_free(f); |
854 | |
855 | return NULL; |
856 | } |
857 | |
858 | // Copy the masking key to the frame |
859 | memcpy(frame_data + header_size, masking_key, 4); |
860 | |
861 | // Apply masking to the payload data |
862 | size_t payload_start = header_size + 4; |
863 | for (size_t i = 0; i < payload_length; i++) |
864 | { |
865 | frame_data[payload_start + i] = f->data[i] ^ masking_key[i % 4]; |
866 | } |
867 | } |
868 | else |
869 | { |
870 | // Copy the payload data without masking |
871 | memcpy(frame_data + header_size, f->data, payload_length); |
872 | } |
873 | |
874 | //> Section 4: Finalizing |
875 | |
876 | // Free the frame |
877 | vws_frame_free(f); |
878 | |
879 | // Create the vws_buffer to hold the frame data |
880 | vws_buffer* buffer = vws_buffer_new(); |
881 | |
882 | // Have buffer take ownership of data |
883 | buffer->data = frame_data; |
884 | buffer->size = frame_size; |
885 | |
886 | vws.success(); |
887 | |
888 | return buffer; |
889 | } |
890 | |
891 | fs_t vws_deserialize(ucstr data, size_t size, vws_frame* f, size_t* consumed) |
892 | { |
893 | // Check if the data contains the minimum required frame header bytes |
894 | if (size < 2) |
895 | { |
896 | return FRAME_INCOMPLETE; |
897 | } |
898 | |
899 | // Read the first byte (FIN bit and opcode) |
900 | f->fin = (data[0] >> 7) & 0x01; |
901 | f->opcode = data[0] & 0x0F; |
902 | |
903 | // Read the second byte (mask bit and payload length) |
904 | f->mask = (data[1] >> 7) & 0x01; |
905 | f->size = data[1] & 0x7F; |
906 | |
907 | // Check if the payload length requires additional bytes |
908 | size_t size_bytes = 0; |
909 | if (f->size == 126) |
910 | { |
911 | size_bytes = 2; |
912 | } |
913 | else if (f->size == 127) |
914 | { |
915 | size_bytes = 8; |
916 | } |
917 | |
918 | // Check if the data contains complete frame header and payload |
919 | size_t required_bytes = 2 + size_bytes; |
920 | |
921 | if (size < required_bytes) |
922 | { |
923 | return FRAME_INCOMPLETE; |
924 | } |
925 | |
926 | // Read the payload |
927 | if (size_bytes > 0) |
928 | { |
929 | f->size = 0; |
930 | for (size_t i = 0; i < size_bytes; i++) |
931 | { |
932 | f->size = (f->size << 8) | data[2 + i]; |
933 | } |
934 | } |
935 | |
936 | // Check if the frame has masking key and payload data |
937 | if (f->mask) |
938 | { |
939 | // Check if the data contains the masking key and payload data |
940 | required_bytes += 4 + f->size; |
941 | |
942 | if (size < required_bytes) |
943 | { |
944 | return FRAME_INCOMPLETE; |
945 | } |
946 | |
947 | // Store the payload offset |
948 | f->offset = 2 + size_bytes + 4; |
949 | |
950 | // Allocate the frame data |
951 | f->data = vws.malloc(f->size); |
952 | |
953 | // Create a temp variable for the masking key |
954 | unsigned char mask[4]; |
955 | memcpy(mask, data + 2 + size_bytes, 4); |
956 | |
957 | // Read the payload data and apply the masking |
958 | for (size_t i = 0; i < f->size; i++) |
959 | { |
960 | f->data[i] = data[f->offset + i] ^ mask[i % 4]; |
961 | } |
962 | } |
963 | else |
964 | { |
965 | // Check if the data contains the payload data |
966 | |
967 | required_bytes += f->size; |
968 | |
969 | if (size < required_bytes) |
970 | { |
971 | return FRAME_INCOMPLETE; |
972 | } |
973 | |
974 | // Store the payload offset |
975 | f->offset = 2 + size_bytes; |
976 | |
977 | // Allocate the frame data |
978 | f->data = vws.malloc(f->size); |
979 | |
980 | // Copy the payload data |
981 | memcpy(f->data, data + f->offset, f->size); |
982 | } |
983 | |
984 | // Update the bytes consumed |
985 | *consumed = required_bytes; |
986 | |
987 | return FRAME_COMPLETE; |
988 | } |
989 | |
990 | //------------------------------------------------------------------------------ |
991 | // Utility functions |
992 | //------------------------------------------------------------------------------ |
993 | |
994 | void process_frame(vws_cnx* c, vws_frame* f) |
995 | { |
996 | switch (f->opcode) |
997 | { |
998 | case CLOSE_FRAME: |
999 | { |
1000 | // Set closing state |
1001 | vws_set_flag(&c->flags, CNX_CLOSING); |
1002 | |
1003 | // Build the response frame |
1004 | vws_buffer* buffer = vws_generate_close_frame(); |
1005 | |
1006 | // Send the response frame |
1007 | vws_socket_write((vws_socket*)c, buffer->data, buffer->size); |
1008 | |
1009 | // Clean up |
1010 | vws_buffer_free(buffer); |
1011 | vws_frame_free(f); |
1012 | |
1013 | break; |
1014 | } |
1015 | |
1016 | case TEXT_FRAME: |
1017 | case BINARY_FRAME: |
1018 | case CONTINUATION_FRAME: |
1019 | { |
1020 | // Add to queue |
1021 | sc_queue_add_first(&c->queue, f); |
1022 | |
1023 | break; |
1024 | } |
1025 | |
1026 | case PING_FRAME: |
1027 | { |
1028 | // Generate the PONG response |
1029 | vws_buffer* buffer = vws_generate_pong_frame(f->data, f->size); |
1030 | |
1031 | // Send the PONG response |
1032 | vws_socket_write((vws_socket*)c, buffer->data, buffer->size); |
1033 | |
1034 | // Clean up |
1035 | vws_buffer_free(buffer); |
1036 | vws_frame_free(f); |
1037 | |
1038 | break; |
1039 | } |
1040 | |
1041 | case PONG_FRAME: |
1042 | { |
1043 | // No need to send a response |
1044 | |
1045 | vws_frame_free(f); |
1046 | |
1047 | break; |
1048 | } |
1049 | |
1050 | default: |
1051 | { |
1052 | // Invalid frame type |
1053 | vws_frame_free(f); |
1054 | } |
1055 | } |
1056 | |
1057 | vws.success(); |
1058 | } |
1059 | |
1060 | char* generate_websocket_key() |
1061 | { |
1062 | // Generate a random 16-byte value |
1063 | unsigned char random_bytes[16]; |
1064 | if (RAND_bytes(random_bytes, sizeof(random_bytes)) != 1) |
1065 | { |
1066 | return NULL; |
1067 | } |
1068 | |
1069 | // Base64-encode the random bytes |
1070 | char* encoded_key = vws_base64_encode(random_bytes, sizeof(random_bytes)); |
1071 | |
1072 | if (encoded_key == NULL) |
1073 | { |
1074 | return NULL; |
1075 | } |
1076 | |
1077 | return encoded_key; |
1078 | } |
1079 | |
1080 | cstr vws_accept_key(cstr key) |
1081 | { |
1082 | // Concatenate the key and WebSocket GUID |
1083 | const char* websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" ; |
1084 | size_t key_length = strlen(key); |
1085 | size_t guid_length = strlen(websocket_guid); |
1086 | size_t input_length = key_length + guid_length; |
1087 | char* input = (char*)vws.malloc(input_length + 1); |
1088 | |
1089 | strncpy(input, key, key_length); |
1090 | strncpy(input + key_length, websocket_guid, guid_length); |
1091 | input[input_length] = '\0'; |
1092 | |
1093 | // Compute the SHA-1 hash of the concatenated value |
1094 | unsigned char hash[SHA_DIGEST_LENGTH]; |
1095 | SHA1((const unsigned char*)input, input_length, hash); |
1096 | |
1097 | // Base64-encode the hash |
1098 | char* encoded_hash = vws_base64_encode(hash, sizeof(hash)); |
1099 | |
1100 | vws.free(input); |
1101 | |
1102 | return encoded_hash; |
1103 | } |
1104 | |
1105 | int verify_handshake(const char* key, const char* response) |
1106 | { |
1107 | char* hash = vws_accept_key(key); |
1108 | int result = strcmp(hash, response); |
1109 | vws.free(hash); |
1110 | |
1111 | return result == 0; |
1112 | } |
1113 | |
1114 | ssize_t socket_wait_for_frame(vws_cnx* c) |
1115 | { |
1116 | // Default success unless error |
1117 | vws.success(); |
1118 | |
1119 | if (vws_cnx_is_connected(c) == false) |
1120 | { |
1121 | return -1; |
1122 | } |
1123 | |
1124 | ssize_t n; |
1125 | unsigned char buf[1024]; |
1126 | |
1127 | while (true) |
1128 | { |
1129 | memset(buf, 0, 1024); |
1130 | n = vws_socket_read((vws_socket*)c); |
1131 | |
1132 | if (n <= 0) |
1133 | { |
1134 | // Error already set |
1135 | return n; |
1136 | } |
1137 | |
1138 | if (vws_cnx_ingress(c) > 0) |
1139 | { |
1140 | break; |
1141 | } |
1142 | } |
1143 | |
1144 | return n; |
1145 | } |
1146 | |
1147 | ssize_t vws_cnx_ingress(vws_cnx* c) |
1148 | { |
1149 | size_t total_consumed = 0; |
1150 | |
1151 | // Process as many frames as possible |
1152 | while (true) |
1153 | { |
1154 | // If there is no more data in socket buffer |
1155 | if (c->base.buffer->size == 0) |
1156 | { |
1157 | break; |
1158 | } |
1159 | |
1160 | // Attempt to parse complete frame |
1161 | size_t consumed = 0; |
1162 | vws_buffer* b = c->base.buffer; |
1163 | vws_frame* frame = vws_frame_new(NULL, 0, TEXT_FRAME); |
1164 | |
1165 | if (vws.tracelevel >= VT_PROTOCOL) |
1166 | { |
1167 | vws_trace_lock(); |
1168 | printf("\n+----------------------------------------------------+\n" ); |
1169 | printf("| Frame Received |\n" ); |
1170 | printf("+----------------------------------------------------+\n" ); |
1171 | vws_dump_websocket_frame(b->data, b->size); |
1172 | printf("------------------------------------------------------\n" ); |
1173 | vws_trace_unlock(); |
1174 | } |
1175 | |
1176 | fs_t rc = vws_deserialize(b->data, b->size, frame, &consumed); |
1177 | |
1178 | if (rc == FRAME_ERROR) |
1179 | { |
1180 | vws.error(VE_WARN, "FRAME_ERROR" ); |
1181 | vws_frame_free(frame); |
1182 | |
1183 | return 0; |
1184 | } |
1185 | |
1186 | if (rc == FRAME_INCOMPLETE) |
1187 | { |
1188 | // No complete frame in socket buffer |
1189 | vws_frame_free(frame); |
1190 | |
1191 | return 0; |
1192 | } |
1193 | |
1194 | // Update |
1195 | total_consumed += consumed; |
1196 | |
1197 | // We have a frame. Process it. |
1198 | c->process(c, frame); |
1199 | |
1200 | // Drain the consumed frame data from buffer |
1201 | vws_buffer_drain(c->base.buffer, consumed); |
1202 | } |
1203 | |
1204 | vws.success(); |
1205 | |
1206 | return total_consumed; |
1207 | } |
1208 | |
1209 | vws_buffer* vws_generate_close_frame() |
1210 | { |
1211 | size_t size = sizeof(int16_t); |
1212 | int16_t* data = vws.malloc(size); |
1213 | |
1214 | // Convert to network byte order before assignement |
1215 | *data = htons(WS_CLOSE_NORMAL); |
1216 | vws_frame* f = vws_frame_new((ucstr)data, size, CLOSE_FRAME); |
1217 | |
1218 | vws.free(data); |
1219 | |
1220 | return vws_serialize(f); |
1221 | } |
1222 | |
1223 | vws_buffer* vws_generate_pong_frame(ucstr ping_data, size_t s) |
1224 | { |
1225 | // We create a new frame with the same data as the received ping frame |
1226 | vws_frame* f = vws_frame_new(ping_data, s, PONG_FRAME); |
1227 | |
1228 | // Serialize the frame and return it |
1229 | return vws_serialize(f); |
1230 | } |
1231 | |
1232 | vws_msg* vws_msg_pop(vws_cnx* c) |
1233 | { |
1234 | if (has_complete_message(c) == false) |
1235 | { |
1236 | return NULL; |
1237 | } |
1238 | |
1239 | // Create new message |
1240 | vws_msg* m = vws_msg_new(); |
1241 | |
1242 | // Set to sentinel value to detect first frame |
1243 | m->opcode = 100; |
1244 | |
1245 | do |
1246 | { |
1247 | vws_frame* f = sc_queue_del_last(&c->queue); |
1248 | |
1249 | // If this is first frame, opcode is sentinel value. We take the opcode |
1250 | // from the first frame only |
1251 | if (m->opcode == 100) |
1252 | { |
1253 | m->opcode = f->opcode; |
1254 | } |
1255 | |
1256 | // Copy frame data into message buffer |
1257 | vws_buffer_append(m->data, f->data, f->size); |
1258 | |
1259 | // Is this the completion frame? |
1260 | bool complete = (f->fin == 1); |
1261 | |
1262 | // Free frame |
1263 | vws_frame_free(f); |
1264 | |
1265 | if (complete) |
1266 | { |
1267 | break; |
1268 | } |
1269 | } |
1270 | while (true); |
1271 | |
1272 | return m; |
1273 | } |
1274 | |
1275 | bool has_complete_message(vws_cnx* c) |
1276 | { |
1277 | vws_frame* f; |
1278 | sc_queue_foreach (&c->queue, f) |
1279 | { |
1280 | if (f->fin == 1) |
1281 | { |
1282 | return true; |
1283 | } |
1284 | } |
1285 | |
1286 | return false; |
1287 | } |
1288 | |
1289 | void (const ws_header* ) |
1290 | { |
1291 | printf(" fin: %u\n" , header->fin); |
1292 | printf(" opcode: %u\n" , header->opcode); |
1293 | printf(" mask: %u (0x%08x)\n" , header->mask, header->masking_key); |
1294 | printf(" payload: %lu bytes\n" , header->payload_len); |
1295 | printf("\n" ); |
1296 | } |
1297 | |
1298 | void vws_dump_websocket_frame(const uint8_t* frame, size_t size) |
1299 | { |
1300 | if (size < 2) |
1301 | { |
1302 | printf("Invalid WebSocket frame\n" ); |
1303 | return; |
1304 | } |
1305 | |
1306 | ws_header ; |
1307 | size_t = 2; |
1308 | header.fin = (frame[0] & 0x80) >> 7; |
1309 | header.opcode = frame[0] & 0x0F; |
1310 | header.mask = (frame[1] & 0x80) >> 7; |
1311 | header.payload_len = frame[1] & 0x7F; |
1312 | |
1313 | if (header.payload_len == 126) |
1314 | { |
1315 | if (size < 4) |
1316 | { |
1317 | printf("Invalid WebSocket frame\n" ); |
1318 | return; |
1319 | } |
1320 | |
1321 | header_size += 2; |
1322 | header.payload_len = ((uint64_t)frame[2] << 8) | frame[3]; |
1323 | } |
1324 | else if (header.payload_len == 127) |
1325 | { |
1326 | if (size < 10) |
1327 | { |
1328 | printf("Invalid WebSocket frame\n" ); |
1329 | return; |
1330 | } |
1331 | |
1332 | header_size += 8; |
1333 | header.payload_len = |
1334 | ((uint64_t)frame[2] << 56) | |
1335 | ((uint64_t)frame[3] << 48) | |
1336 | ((uint64_t)frame[4] << 40) | |
1337 | ((uint64_t)frame[5] << 32) | |
1338 | ((uint64_t)frame[6] << 24) | |
1339 | ((uint64_t)frame[7] << 16) | |
1340 | ((uint64_t)frame[8] << 8) | |
1341 | frame[9]; |
1342 | } |
1343 | |
1344 | if (header.mask) |
1345 | { |
1346 | if (size < header_size + 4) |
1347 | { |
1348 | printf("Invalid WebSocket frame\n" ); |
1349 | return; |
1350 | } |
1351 | header.masking_key = |
1352 | ((uint32_t)frame[header_size] << 24) | |
1353 | ((uint32_t)frame[header_size + 1] << 16) | |
1354 | ((uint32_t)frame[header_size + 2] << 8) | |
1355 | frame[header_size + 3]; |
1356 | header_size += 4; |
1357 | } |
1358 | else |
1359 | { |
1360 | header.masking_key = 0; |
1361 | } |
1362 | |
1363 | printf(" header: %zu bytes\n" , header_size); |
1364 | dump_websocket_header(&header); |
1365 | |
1366 | if (size > header_size) |
1367 | { |
1368 | for (size_t i = header_size; i < size; ++i) |
1369 | { |
1370 | printf("%02x " , frame[i]); |
1371 | if ((i - header_size + 1) % 16 == 0) |
1372 | { |
1373 | printf("\n" ); |
1374 | } |
1375 | } |
1376 | |
1377 | printf("\n" ); |
1378 | } |
1379 | } |
1380 | |