1 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
2 | #include <unistd.h> |
3 | #include <fcntl.h> |
4 | #include <netdb.h> |
5 | #include <poll.h> |
6 | #include <sys/types.h> |
7 | #include <sys/socket.h> |
8 | #endif |
9 | |
10 | #if defined(__windows__) |
11 | #define _WIN32_WINNT 0x0601 // Windows 7 or later |
12 | #include <winsock2.h> |
13 | #include <ws2tcpip.h> |
14 | #endif |
15 | |
16 | #include <assert.h> |
17 | #include <string.h> |
18 | |
19 | #include <openssl/rand.h> |
20 | |
21 | #include "socket.h" |
22 | |
23 | //------------------------------------------------------------------------------ |
24 | // Internal functions |
25 | //------------------------------------------------------------------------------ |
26 | |
27 | /** @brief Defines the various states of a WebSocket connection */ |
28 | typedef enum |
29 | { |
30 | /** The connection with the client is in the initial SSL handshake phase. */ |
31 | CNX_SSL_INIT = (1 << 3), |
32 | |
33 | } socket_flags_t; |
34 | |
35 | /** |
36 | * @brief Connects to a host at a specific port and returns the connection |
37 | * status. |
38 | * |
39 | * @param host The host to connect to. |
40 | * @param port The port to connect to. |
41 | * @return The connection status, 0 if successful, an error code otherwise. |
42 | * |
43 | * @ingroup ConnectionFunctions |
44 | */ |
45 | static int connect_to_host(const char* host, const char* port); |
46 | |
47 | /** |
48 | * @brief Sets a timeout on a socket read/write operations. |
49 | * |
50 | * @param fd The socket file descriptor. |
51 | * @param sec The timeout value in seconds. |
52 | * @return True if successful, false otherwise. |
53 | * |
54 | * @ingroup SocketFunctions |
55 | */ |
56 | static bool socket_set_timeout(int fd, int sec); |
57 | |
58 | /** |
59 | * @brief Calls handler for unexpected socket closure. |
60 | */ |
61 | static void socket_abnormal_close(vws_socket* c); |
62 | |
63 | //------------------------------------------------------------------------------ |
64 | //> Socket API |
65 | //------------------------------------------------------------------------------ |
66 | |
67 | vws_socket* vws_socket_new() |
68 | { |
69 | vws_socket* c = (vws_socket*)vws.malloc(sizeof(vws_socket)); |
70 | memset(c, 0, sizeof(vws_socket)); |
71 | |
72 | return vws_socket_ctor(c); |
73 | } |
74 | |
75 | vws_socket* vws_socket_ctor(vws_socket* s) |
76 | { |
77 | s->sockfd = -1; |
78 | s->buffer = vws_buffer_new(); |
79 | s->ssl = NULL; |
80 | s->timeout = 10000; |
81 | s->data = NULL; |
82 | s->hs = NULL; |
83 | s->disconnect = NULL; |
84 | s->flush = true; |
85 | |
86 | return s; |
87 | } |
88 | |
89 | void vws_socket_free(vws_socket* c) |
90 | { |
91 | if (c == NULL) |
92 | { |
93 | return; |
94 | } |
95 | |
96 | vws_socket_dtor(c); |
97 | } |
98 | |
99 | void vws_socket_dtor(vws_socket* s) |
100 | { |
101 | vws_socket_disconnect(s); |
102 | |
103 | // Free receive buffer |
104 | vws_buffer_free(s->buffer); |
105 | |
106 | if (s->sockfd >= 0) |
107 | { |
108 | close(s->sockfd); |
109 | } |
110 | |
111 | // Free connection |
112 | vws.free(s); |
113 | } |
114 | |
115 | //------------------------------------------------------------------------------ |
116 | // Utility functions |
117 | //------------------------------------------------------------------------------ |
118 | |
119 | void socket_abnormal_close(vws_socket* c) |
120 | { |
121 | // If disconnect callback is registered |
122 | if (c->disconnect != NULL) |
123 | { |
124 | // Call it |
125 | c->disconnect(c); |
126 | } |
127 | |
128 | vws_socket_close(c); |
129 | } |
130 | |
131 | bool vws_socket_set_timeout(vws_socket* s, int sec) |
132 | { |
133 | if (socket_set_timeout(s->sockfd, sec) == false) |
134 | { |
135 | return false; |
136 | } |
137 | |
138 | // Set socket attribute, this will apply to poll(). |
139 | s->timeout = sec; |
140 | |
141 | return true; |
142 | } |
143 | |
144 | bool socket_set_timeout(int fd, int sec) |
145 | { |
146 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
147 | |
148 | if (fd < 0) |
149 | { |
150 | vws.error(VE_RT, "Invalid socket descriptor" ); |
151 | return false; |
152 | } |
153 | |
154 | if (sec == -1) |
155 | { |
156 | sec = 0; |
157 | } |
158 | |
159 | struct timeval tm; |
160 | tm.tv_sec = sec; |
161 | tm.tv_usec = 0; |
162 | |
163 | // Set the send timeout |
164 | if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (char *)&tm, sizeof(tm)) < 0) |
165 | { |
166 | vws.error(VE_SYS, "setsockopt failed" ); |
167 | |
168 | return false; |
169 | } |
170 | |
171 | // Set the receive timeout |
172 | if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (char *)&tm, sizeof(tm)) < 0) |
173 | { |
174 | vws.error(VE_SYS, "setsockopt failed" ); |
175 | |
176 | return false; |
177 | } |
178 | |
179 | #elif defined(__windows__) |
180 | |
181 | if (fd == INVALID_SOCKET) |
182 | { |
183 | vws.error(VE_RT, "Invalid socket descriptor" ); |
184 | return false; |
185 | } |
186 | |
187 | // Convert from sec to ms for Windows |
188 | DWORD tm = sec * 1000; |
189 | |
190 | if (sec == -1) |
191 | { |
192 | // Maximum value (136.17 years) |
193 | sec = 4294967295; |
194 | } |
195 | |
196 | // Set the send timeout |
197 | if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (cstr)&tm, sizeof(tm)) < 0) |
198 | { |
199 | vws.error(VE_SYS, "setsockopt failed" ); |
200 | |
201 | return false; |
202 | } |
203 | |
204 | // Set the receive timeout |
205 | if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (cstr)&tm, sizeof(tm)) < 0) |
206 | { |
207 | vws.error(VE_SYS, "setsockopt failed" ); |
208 | |
209 | return false; |
210 | } |
211 | |
212 | #else |
213 | #error Platform not supported |
214 | #endif |
215 | |
216 | vws.success(); |
217 | |
218 | return true; |
219 | } |
220 | |
221 | bool vws_socket_set_nonblocking(int sockfd) |
222 | { |
223 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
224 | |
225 | int flags = fcntl(sockfd, F_GETFL, 0); |
226 | |
227 | if (flags == -1) |
228 | { |
229 | vws.error(VE_SYS, "fcntl(sockfd, F_GETFL, 0) failed" ); |
230 | |
231 | return false; |
232 | } |
233 | |
234 | if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) |
235 | { |
236 | vws.error(VE_SYS, "fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) failed" ); |
237 | |
238 | return false; |
239 | } |
240 | |
241 | #elif defined(__windows__) |
242 | |
243 | unsigned long arg = 1; |
244 | if (ioctlsocket(sockfd, FIONBIO, &arg) == SOCKET_ERROR) |
245 | { |
246 | vws.error(VE_SYS, "ioctlsocket(sockfd, FIONBIO, &arg)" ); |
247 | |
248 | return false; |
249 | } |
250 | |
251 | #else |
252 | #error Platform not supported |
253 | #endif |
254 | |
255 | vws.success(); |
256 | |
257 | return true; |
258 | } |
259 | |
260 | bool vws_socket_is_connected(vws_socket* c) |
261 | { |
262 | if (c == NULL) |
263 | { |
264 | return false; |
265 | } |
266 | |
267 | return c->sockfd > -1; |
268 | } |
269 | |
270 | bool vws_socket_connect(vws_socket* c, cstr host, int port, bool ssl) |
271 | { |
272 | if (c == NULL) |
273 | { |
274 | // Return early if failed to create a connection. |
275 | vws.error(VE_RT, "Invalid connection pointer()" ); |
276 | return false; |
277 | } |
278 | |
279 | // Clear socket buffer in case it was previously used in other connection. |
280 | vws_buffer_clear(c->buffer); |
281 | |
282 | if (ssl == true) |
283 | { |
284 | if (vws_is_flag(&vws.state, CNX_SSL_INIT) == false) |
285 | { |
286 | SSL_library_init(); |
287 | RAND_poll(); |
288 | SSL_load_error_strings(); |
289 | |
290 | vws_ssl_ctx = SSL_CTX_new(TLS_method()); |
291 | |
292 | if (vws_ssl_ctx == NULL) |
293 | { |
294 | vws.error(VE_SYS, "Failed to create new SSL context" ); |
295 | return false; |
296 | } |
297 | |
298 | SSL_CTX_set_options(vws_ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); |
299 | } |
300 | |
301 | c->ssl = SSL_new(vws_ssl_ctx); |
302 | |
303 | if (c->ssl == NULL) |
304 | { |
305 | vws.error(VE_SYS, "Failed to create new SSL object" ); |
306 | vws_socket_close(c); |
307 | return false; |
308 | } |
309 | |
310 | vws_set_flag(&vws.state, CNX_SSL_INIT); |
311 | } |
312 | |
313 | char port_str[20]; |
314 | sprintf(port_str, "%d" , port); |
315 | c->sockfd = connect_to_host(host, port_str); |
316 | |
317 | if (c->sockfd < 0) |
318 | { |
319 | vws.error(VE_SYS, "Connection failed" ); |
320 | vws_socket_close(c); |
321 | return false; |
322 | } |
323 | |
324 | // Set default timeout |
325 | if (socket_set_timeout(c->sockfd, c->timeout/1000) == false) |
326 | { |
327 | // Error already set |
328 | vws_socket_close(c); |
329 | return false; |
330 | } |
331 | |
332 | if (c->ssl != NULL) |
333 | { |
334 | SSL_set_fd(c->ssl, c->sockfd); |
335 | |
336 | if (SSL_connect(c->ssl) <= 0) |
337 | { |
338 | vws.error(VE_SYS, "SSL connection failed" ); |
339 | vws_socket_close(c); |
340 | return false; |
341 | } |
342 | } |
343 | |
344 | #if defined(__bsd__) |
345 | |
346 | // Disable SIGPIPE |
347 | int val = 1; |
348 | setsockopt(c->sockfd, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val)); |
349 | |
350 | #endif |
351 | |
352 | // Go into non-blocking mode as we are using poll() for socket_read() and |
353 | // socket_write(). |
354 | if (vws_socket_set_nonblocking(c->sockfd) == false) |
355 | { |
356 | // Error already set |
357 | vws_socket_close(c); |
358 | return false; |
359 | } |
360 | |
361 | // Check if handshake handler is registered |
362 | if (c->hs != NULL) |
363 | { |
364 | if (c->hs(c) == false) |
365 | { |
366 | vws.error(VE_SYS, "Handshake failed" ); |
367 | vws_socket_close(c); |
368 | return false; |
369 | } |
370 | } |
371 | |
372 | vws.success(); |
373 | |
374 | return true; |
375 | } |
376 | |
377 | void vws_socket_disconnect(vws_socket* c) |
378 | { |
379 | if (vws_socket_is_connected(c) == false) |
380 | { |
381 | return; |
382 | } |
383 | |
384 | vws_socket_close(c); |
385 | |
386 | vws.success(); |
387 | } |
388 | |
389 | ssize_t vws_socket_read(vws_socket* c) |
390 | { |
391 | // Default success unless error |
392 | vws.success(); |
393 | |
394 | if (vws_socket_is_connected(c) == false) |
395 | { |
396 | vws.error(VE_SOCKET, "vws_socket_read()" ); |
397 | return -1; |
398 | } |
399 | |
400 | // Validate input parameters |
401 | if (c == NULL) |
402 | { |
403 | vws.error(VE_WARN, "Invalid parameters" ); |
404 | return -1; |
405 | } |
406 | |
407 | struct pollfd fds; |
408 | int poll_events = POLLIN; |
409 | |
410 | openssl_reread: |
411 | |
412 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
413 | |
414 | fds.fd = c->sockfd; |
415 | fds.events = poll_events; |
416 | |
417 | int rc = poll(&fds, 1, c->timeout); |
418 | |
419 | if (fds.revents & (POLLERR | POLLHUP | POLLNVAL)) |
420 | { |
421 | vws.error(VE_SOCKET, "Socket error during poll()" ); |
422 | socket_abnormal_close(c); |
423 | return -1; |
424 | } |
425 | |
426 | #elif defined(__windows__) |
427 | |
428 | WSAPOLLFD fds; |
429 | fds.fd = c->sockfd; |
430 | fds.events = POLLIN; |
431 | |
432 | int rc = WSAPoll(&fds, 1, c->timeout); |
433 | |
434 | if (rc == SOCKET_ERROR) |
435 | { |
436 | vws.error(VE_SOCKET, "Socket error during WSAPoll()" ); |
437 | socket_abnormal_close(c); |
438 | |
439 | return -1; |
440 | } |
441 | |
442 | #else |
443 | #error Platform not supported |
444 | #endif |
445 | |
446 | if (rc == -1) |
447 | { |
448 | vws.error(VE_RT, "poll() failed" ); |
449 | return -1; |
450 | } |
451 | |
452 | if (rc == 0) |
453 | { |
454 | vws.error(VE_TIMEOUT, "poll()" ); |
455 | return 0; |
456 | } |
457 | |
458 | ssize_t n = 0; |
459 | ucstr data = &vws.sslbuf[0]; |
460 | ssize_t size = sizeof(vws.sslbuf); |
461 | |
462 | if (fds.revents & poll_events) |
463 | { |
464 | if (c->ssl != NULL) |
465 | { |
466 | // We need running total bc we may make multiple SSL_read() calls. |
467 | int total = 0; |
468 | |
469 | // Drain all data from SSL buffer |
470 | while ((n = SSL_read(c->ssl, data, size)) > 0) |
471 | { |
472 | // Process received data stored in buf |
473 | total += n; |
474 | vws_buffer_append(c->buffer, data, n); |
475 | |
476 | if (n < size) |
477 | { |
478 | // All available data has been read, break the loop. |
479 | break; |
480 | } |
481 | } |
482 | |
483 | // Check for error conditions |
484 | if (n <= 0) |
485 | { |
486 | int err = SSL_get_error(c->ssl, n); |
487 | |
488 | if (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ) |
489 | { |
490 | // SSL needs to do something on socket in order to continue. |
491 | |
492 | // If it wants to read data |
493 | if (err == SSL_ERROR_WANT_READ) |
494 | { |
495 | // We are done. We have emptied the read buffer. |
496 | return total; |
497 | } |
498 | |
499 | // If it wants to write data |
500 | if (err == SSL_ERROR_WANT_WRITE) |
501 | { |
502 | // It's doing some internal negotiation and we need to |
503 | // help it along by running poll() for writes. Then we |
504 | // will return to SSL_read() in which SSL will send out |
505 | // the data it needs to. |
506 | poll_events = POLLOUT; |
507 | goto openssl_reread; |
508 | } |
509 | } |
510 | else if (err == SSL_ERROR_SYSCALL) |
511 | { |
512 | #if defined(__windows__) |
513 | |
514 | int err = WSAGetLastError(); |
515 | |
516 | if (err == WSAEWOULDBLOCK || err == WSAEINPROGRESS) |
517 | { |
518 | vws.error(VE_TIMEOUT, "SSL_read()" ); |
519 | return 0; |
520 | } |
521 | |
522 | #else |
523 | |
524 | if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) |
525 | { |
526 | vws.error(VE_TIMEOUT, "SSL_read()" ); |
527 | return 0; |
528 | } |
529 | |
530 | #endif |
531 | } |
532 | |
533 | // Get the latest OpenSSL error |
534 | char buf[256]; |
535 | unsigned long ssl_err = ERR_get_error(); |
536 | ERR_error_string_n(ssl_err, buf, sizeof(buf)); |
537 | vws.error(VE_SOCKET, "SSL_read() failed: %s" , buf); |
538 | |
539 | // Close socket |
540 | socket_abnormal_close(c); |
541 | |
542 | return -1; |
543 | } |
544 | } |
545 | else |
546 | { |
547 | // Non-SSL socket is readable, perform recv() operation |
548 | #if defined(__linux__) || defined(__sunos__) |
549 | n = recv(c->sockfd, data, size, MSG_NOSIGNAL); |
550 | #else |
551 | n = recv(c->sockfd, data, size, 0); |
552 | #endif |
553 | |
554 | if (n == 0) |
555 | { |
556 | vws.error(VE_SOCKET, "disconnect" ); |
557 | |
558 | // Close socket |
559 | socket_abnormal_close(c); |
560 | |
561 | return -1; |
562 | } |
563 | |
564 | if (n <= -1) |
565 | { |
566 | #if defined(__windows__) |
567 | int err = WSAGetLastError(); |
568 | |
569 | if (err == WSAEWOULDBLOCK || err == WSAEINPROGRESS) |
570 | { |
571 | vws.error(VE_TIMEOUT, "recv()" ); |
572 | return 0; |
573 | } |
574 | |
575 | #else |
576 | |
577 | if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) |
578 | { |
579 | vws.error(VE_TIMEOUT, "recv()" ); |
580 | return 0; |
581 | } |
582 | else |
583 | { |
584 | // Error |
585 | char buf[256]; |
586 | strerror_r(errno, buf, sizeof(buf)); |
587 | vws.error(VE_WARN, "recv() failed: %s" , buf); |
588 | |
589 | // Close socket |
590 | socket_abnormal_close(c); |
591 | |
592 | return -1; |
593 | } |
594 | #endif |
595 | } |
596 | |
597 | // Should always be true if we get here |
598 | if (n > 0) |
599 | { |
600 | vws_buffer_append(c->buffer, data, n); |
601 | } |
602 | } |
603 | } |
604 | |
605 | return n; |
606 | } |
607 | |
608 | ssize_t vws_socket_write(vws_socket* c, const ucstr data, size_t size) |
609 | { |
610 | // Default success unless error |
611 | vws.success(); |
612 | |
613 | if (vws_socket_is_connected(c) == false) |
614 | { |
615 | vws.error(VE_SOCKET, "vws_socket_write()" ); |
616 | return -1; |
617 | } |
618 | |
619 | // Validate input parameters |
620 | if (c == NULL || data == NULL || size == 0) |
621 | { |
622 | vws.error(VE_WARN, "Invalid parameters" ); |
623 | return -1; |
624 | } |
625 | |
626 | // But default we will keep looping until we have sent all the data |
627 | size_t sent = 0; |
628 | int poll_events = POLLOUT; |
629 | int iterations = 0; |
630 | while (sent < size) |
631 | { |
632 | // If we attempted at least one poll()/send() |
633 | if (iterations++ > 0) |
634 | { |
635 | // And we are not set to flush mode, then we will return here, |
636 | // sending back how much data we sent. The caller will need to |
637 | // adjust the buffer accordingly. |
638 | if (c->flush == false) |
639 | { |
640 | break; |
641 | } |
642 | } |
643 | |
644 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
645 | |
646 | struct pollfd fds; |
647 | fds.fd = c->sockfd; |
648 | fds.events = poll_events; |
649 | |
650 | int rc = poll(&fds, 1, c->timeout); |
651 | |
652 | if (fds.revents & (POLLERR | POLLHUP | POLLNVAL)) |
653 | { |
654 | vws.error(VE_SOCKET, "Socket error during poll()" ); |
655 | socket_abnormal_close(c); |
656 | return -1; |
657 | } |
658 | |
659 | #elif defined(__windows__) |
660 | |
661 | WSAPOLLFD fds; |
662 | fds.fd = c->sockfd; |
663 | fds.events = poll_events; |
664 | |
665 | int rc = WSAPoll(&fds, 1, c->timeout); |
666 | |
667 | if (rc == SOCKET_ERROR) |
668 | { |
669 | vws.error(VE_SOCKET, "Socket error during WSAPoll()" ); |
670 | socket_abnormal_close(c); |
671 | return -1; |
672 | } |
673 | |
674 | #else |
675 | #error Platform not supported |
676 | #endif |
677 | |
678 | if (rc == -1) |
679 | { |
680 | vws.error(VE_SYS, "poll() failed" ); |
681 | return -1; |
682 | } |
683 | |
684 | // There was a timeout. Restart loop. Sends are all or nothing: we keep |
685 | // pushing until either all the data goes or the connection |
686 | // drops. Anything else is inconsistent state. |
687 | if (rc == 0) |
688 | { |
689 | // Keep going. |
690 | continue; |
691 | } |
692 | |
693 | ssize_t n = 0; |
694 | if (fds.revents & poll_events) |
695 | { |
696 | if (c->ssl != NULL) |
697 | { |
698 | // SSL socket is writable, perform SSL_write() operation |
699 | n = SSL_write(c->ssl, data + sent, size - sent); |
700 | |
701 | if (n <= 0) |
702 | { |
703 | int err = SSL_get_error(c->ssl, n); |
704 | |
705 | if (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ) |
706 | { |
707 | // The SSL socket is still open but would block on write |
708 | |
709 | if (err == SSL_ERROR_WANT_READ) |
710 | { |
711 | poll_events = POLLIN; |
712 | } |
713 | else |
714 | { |
715 | poll_events = POLLOUT; |
716 | } |
717 | |
718 | // Keep going |
719 | continue; |
720 | } |
721 | else if (err == SSL_ERROR_SYSCALL) |
722 | { |
723 | #if defined(__windows__) |
724 | |
725 | int err = WSAGetLastError(); |
726 | |
727 | if (err == WSAEWOULDBLOCK || err == WSAEINPROGRESS) |
728 | { |
729 | // Timeout. Keep going. |
730 | continue; |
731 | } |
732 | |
733 | #else |
734 | |
735 | if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) |
736 | { |
737 | // Timeout. Keep going. |
738 | continue; |
739 | } |
740 | |
741 | #endif |
742 | } |
743 | |
744 | // Get the latest OpenSSL error |
745 | char buf[256]; |
746 | unsigned long ssl_err = ERR_get_error(); |
747 | ERR_error_string_n(ssl_err, buf, sizeof(buf)); |
748 | vws.error(VE_SOCKET, "SSL_write() failed: %s" , buf); |
749 | |
750 | // Close socket |
751 | socket_abnormal_close(c); |
752 | |
753 | return -1; |
754 | } |
755 | |
756 | // Reset |
757 | poll_events = POLLOUT; |
758 | } |
759 | else |
760 | { |
761 | // Non-SSL socket is writable, perform send() operation |
762 | #if defined(__linux__) || defined(__sunos__) |
763 | |
764 | n = send(c->sockfd, data + sent, size - sent, MSG_NOSIGNAL); |
765 | |
766 | #else |
767 | |
768 | n = send(c->sockfd, data + sent, size - sent, 0); |
769 | |
770 | #endif |
771 | |
772 | if (n <= -1) |
773 | { |
774 | #if defined(__windows__) |
775 | int err = WSAGetLastError(); |
776 | if (err == WSAEWOULDBLOCK || err == WSAEINPROGRESS) |
777 | { |
778 | // The socket is still open but would block on send |
779 | continue; |
780 | } |
781 | #else |
782 | if (errno == EWOULDBLOCK || errno == EAGAIN) |
783 | { |
784 | // The socket is still open but would block on send |
785 | continue; |
786 | } |
787 | #endif |
788 | |
789 | // An error occurred, and the socket might be closed |
790 | vws.error(VE_SYS, "send() error" ); |
791 | |
792 | // Close socket |
793 | socket_abnormal_close(c); |
794 | |
795 | return -1; |
796 | } |
797 | } |
798 | |
799 | if (n > 0) |
800 | { |
801 | sent += n; |
802 | } |
803 | } |
804 | } |
805 | |
806 | return sent; |
807 | } |
808 | |
809 | void vws_socket_close(vws_socket* c) |
810 | { |
811 | if (c->ssl != NULL) |
812 | { |
813 | // Unidirectional shutdown |
814 | int rc = SSL_shutdown(c->ssl); |
815 | |
816 | if (rc < 0) |
817 | { |
818 | // Get the latest OpenSSL error |
819 | char buf[256]; |
820 | unsigned long ssl_err = ERR_get_error(); |
821 | ERR_error_string_n(ssl_err, buf, sizeof(buf)); |
822 | vws.error(VE_WARN, "SSL_shutdown failed: %s" , buf); |
823 | } |
824 | |
825 | SSL_free(c->ssl); |
826 | c->ssl = NULL; |
827 | } |
828 | |
829 | if (c->sockfd >= 0) |
830 | { |
831 | #if defined(__windows__) |
832 | if (closesocket(c->sockfd) == SOCKET_ERROR) |
833 | #else |
834 | if (close(c->sockfd) == -1) |
835 | #endif |
836 | { |
837 | char buf[256]; |
838 | strerror_r(errno, buf, sizeof(buf)); |
839 | vws.error(VE_WARN, "Socket close failed: %s" , buf); |
840 | } |
841 | #if defined(__windows__) |
842 | WSACleanup(); |
843 | #endif |
844 | |
845 | c->sockfd = -1; |
846 | } |
847 | } |
848 | |
849 | int connect_to_host(const char* host, const char* port) |
850 | { |
851 | int sockfd = -1; |
852 | |
853 | #if defined(__linux__) || defined(__bsd__) || defined(__sunos__) |
854 | |
855 | // Resolve the host |
856 | struct addrinfo hints, *res, *res0; |
857 | int error; |
858 | |
859 | memset(&hints, 0, sizeof(hints)); |
860 | hints.ai_family = PF_UNSPEC; // Accept any family (IPv4 or IPv6) |
861 | hints.ai_socktype = SOCK_STREAM; |
862 | |
863 | error = getaddrinfo(host, port, &hints, &res0); |
864 | |
865 | if (error) |
866 | { |
867 | if (vws.tracelevel > 0) |
868 | { |
869 | cstr msg = gai_strerror(error); |
870 | vws.trace(VL_ERROR, "getaddrinfo failed: %s: %s" , host, msg); |
871 | } |
872 | |
873 | vws.error(VE_SYS, "getaddrinfo() failed" ); |
874 | |
875 | return -1; |
876 | } |
877 | |
878 | for (res = res0; res; res = res->ai_next) |
879 | { |
880 | sockfd = socket(res->ai_family, res->ai_socktype, res->ai_protocol); |
881 | |
882 | if (sockfd == -1) |
883 | { |
884 | vws.error(VE_SYS, "Failed to create socket" ); |
885 | continue; |
886 | } |
887 | |
888 | if (connect(sockfd, res->ai_addr, res->ai_addrlen) == -1) |
889 | { |
890 | close(sockfd); |
891 | sockfd = -1; |
892 | |
893 | vws.error(VE_SYS, "Failed to connect" ); |
894 | continue; |
895 | } |
896 | |
897 | break; // If we get here, we must have connected successfully |
898 | } |
899 | |
900 | freeaddrinfo(res0); // Free the addrinfo structure for this host |
901 | |
902 | #elif defined(__windows__) |
903 | |
904 | // Windows specific implementation |
905 | // Please refer to Windows Socket programming guide |
906 | |
907 | WSADATA wsaData; |
908 | struct addrinfo *result = NULL, *ptr = NULL, hints; |
909 | sockfd = INVALID_SOCKET; |
910 | |
911 | // Initialize Winsock |
912 | if (WSAStartup(MAKEWORD(2,2), &wsaData) != 0) |
913 | { |
914 | vws.error(VE_SYS, "WSAStartup failed" ); |
915 | return -1; |
916 | } |
917 | |
918 | ZeroMemory(&hints, sizeof(hints)); |
919 | hints.ai_family = AF_UNSPEC; |
920 | hints.ai_socktype = SOCK_STREAM; |
921 | hints.ai_protocol = IPPROTO_TCP; |
922 | |
923 | // Resolve the server address and port |
924 | if (getaddrinfo(host, port, &hints, &result) != 0) |
925 | { |
926 | vws.error(VE_SYS, "getaddrinfo failed\n" ); |
927 | return -1; |
928 | } |
929 | |
930 | // Attempt to connect to an address until one succeeds |
931 | for (ptr = result; ptr != NULL; ptr =ptr->ai_next) |
932 | { |
933 | // Create a SOCKET for connecting to server |
934 | sockfd = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol); |
935 | |
936 | if (sockfd == INVALID_SOCKET) |
937 | { |
938 | char buf[256]; |
939 | int e = WSAGetLastError(); |
940 | snprintf(buf, sizeof(buf), "socket failed with error: %ld" , e); |
941 | vws.error(VE_RT, buf); |
942 | |
943 | WSACleanup(); |
944 | return -1; |
945 | } |
946 | |
947 | // Connect to server. |
948 | if (connect(sockfd, ptr->ai_addr, (int)ptr->ai_addrlen) == SOCKET_ERROR) |
949 | { |
950 | closesocket(sockfd); |
951 | sockfd = INVALID_SOCKET; |
952 | continue; |
953 | } |
954 | |
955 | break; |
956 | } |
957 | |
958 | freeaddrinfo(result); |
959 | |
960 | if (sockfd == INVALID_SOCKET) |
961 | { |
962 | vws.error(VE_SYS, "Unable to connect to host" ); |
963 | WSACleanup(); |
964 | return -1; |
965 | } |
966 | |
967 | #else |
968 | #error Platform not supported |
969 | #endif |
970 | |
971 | vws.success(); |
972 | |
973 | return sockfd; |
974 | } |
975 | |