1 // Written by Christopher E. Miller
2 // See the included license.txt for copyright and license details.
3 
4 
5 
6 module dfl.socket;
7 
8 
version(WINE) {
   // Building for Wine turns off this module's socket support:
   // DFL_NoSocket selects the empty branch below.
   version = DFL_NoSocket;
}
12 
13 
14 version(DFL_NoSocket) {
15 }
16 else {
17 
18    private import dfl.internal.dlib, dfl.internal.clib;
19 
20    private {
21       private import std.socket, core.bitop;
22       private import std.c.windows.winsock;
23 
24       alias InternetHost DInternetHost;
25       alias InternetAddress DInternetAddress;
26 
27       socket_t getSocketHandle(Socket sock) nothrow @nogc
28       {
29          return sock.handle;
30       }
31    }
32 
33    alias std.socket.Socket DflSocket; ///
34 
35    private import dfl.internal.winapi, dfl.application, dfl.base, dfl.internal.utf;
36 
37 
   private {
      // Winsock network-event bit masks (same values as winsock's FD_* constants).
      // EventType below is built from them and passed to WSAAsyncSelect.
      enum
      {
         FD_READ =       0x01,
         FD_WRITE =      0x02,
         FD_OOB =        0x04,
         FD_ACCEPT =     0x08,
         FD_CONNECT =    0x10,
         FD_CLOSE =      0x20,
         FD_QOS =        0x40,
         FD_GROUP_QOS =  0x80,
      }


      // Winsock API, declared locally. Returns -1 (SOCKET_ERROR) on failure.
      // Passing lEvent == 0 cancels all notifications for the socket.
      extern(Windows) int WSAAsyncSelect(socket_t s, HWND hWnd, UINT wMsg, int lEvent) nothrow @nogc;
   }
54 
55 
56 
   /// Socket events that can be requested through registerEvent().
   /// Can be OR'ed when registering; a callback receives one flag at a time.
   enum EventType {
      NONE = 0, ///

      READ =       FD_READ, /// ditto
      WRITE =      FD_WRITE, /// ditto
      OOB =        FD_OOB, /// ditto
      ACCEPT =     FD_ACCEPT, /// ditto
      CONNECT =    FD_CONNECT, /// ditto
      CLOSE =      FD_CLOSE, /// ditto

      QOS =        FD_QOS, /// ditto
      GROUP_QOS =  FD_GROUP_QOS, /// ditto
   }
71 
72 
73 
   // Callback for registerEvent(), run from the message loop for each event.
   // -err- will be 0 if no error; otherwise it is the Winsock error code.
   // -type- will always contain only one flag.
   alias void delegate(DflSocket sock, EventType type, int err) RegisterEventCallback;
77 
78 
79    // Calling this twice on the same socket cancels out previously
80    // registered events for the socket.
81    // Requires Application.run() or Application.doEvents() loop.
82    void registerEvent(DflSocket sock, EventType events, RegisterEventCallback callback) { // deprecated
83       assert(sock !is null, "registerEvent: socket cannot be null");
84       assert(callback !is null, "registerEvent: callback cannot be null");
85 
86       if(!hwNet) {
87          _init();
88       }
89 
90       sock.blocking = false; // So the getter will be correct.
91 
92       // SOCKET_ERROR
93       if(-1 == WSAAsyncSelect(getSocketHandle(sock), hwNet, WM_DFL_NETEVENT, cast(int)events)) {
94          throw new DflException("Unable to register socket events");
95       }
96 
97       EventInfo ei;
98 
99       ei.sock = sock;
100       ei.callback = callback;
101       allEvents[getSocketHandle(sock)] = ei;
102    }
103 
104 
105    void unregisterEvent(DflSocket sock) @trusted @nogc nothrow { // deprecated
106       WSAAsyncSelect(getSocketHandle(sock), hwNet, 0, 0);
107 
108       //delete allEvents[getSocketHandle(sock)];
109       allEvents.remove(getSocketHandle(sock));
110    }
111 
112 
113 
114    class AsyncSocket: DflSocket { // docmain
115 
116       this(AddressFamily af, SocketType type, ProtocolType protocol) {
117          super(af, type, protocol);
118          super.blocking = false;
119       }
120 
121       /// ditto
122       this(AddressFamily af, SocketType type) {
123          super(af, type);
124          super.blocking = false;
125       }
126 
127       /// ditto
128       this(AddressFamily af, SocketType type, Dstring protocolName) {
129          super(af, type, protocolName);
130          super.blocking = false;
131       }
132 
133       /// ditto
134       // For use with accept().
135       protected this() pure @safe nothrow {
136       }
137 
138 
139 
140       void event(EventType events, RegisterEventCallback callback) {
141          registerEvent(this, events, callback);
142       }
143 
144 
145       protected override AsyncSocket accepting() {
146          return new AsyncSocket;
147       }
148 
149 
150       override void close() {
151          unregisterEvent(this);
152          super.close();
153       }
154 
155 
156       override @property bool blocking() const { // getter
157          return false;
158       }
159 
160 
161       override @property void blocking(bool byes) { // setter
162          if(byes) {
163             assert(0);
164          }
165       }
166 
167    }
168 
169 
170 
171    class AsyncTcpSocket: AsyncSocket { // docmain
172 
173       this(AddressFamily family) {
174          super(family, SocketType.STREAM, ProtocolType.TCP);
175       }
176 
177       /// ditto
178       this() {
179          this(cast(AddressFamily)AddressFamily.INET);
180       }
181 
182       /// ditto
183       // Shortcut.
184       this(Address connectTo, EventType events, RegisterEventCallback eventCallback) {
185          this(connectTo.addressFamily());
186          event(events, eventCallback);
187          connect(connectTo);
188       }
189    }
190 
191 
192 
193    class AsyncUdpSocket: AsyncSocket { // docmain
194 
195       this(AddressFamily family) {
196          super(family, SocketType.DGRAM, ProtocolType.UDP);
197       }
198 
199       /// ditto
200       this() {
201          this(cast(AddressFamily)AddressFamily.INET);
202       }
203    }
204 
205 
206    /+
207    private class GetHostWaitHandle: WaitHandle {
208       this(HANDLE h) {
209          super.handle = h;
210       }
211 
212 
213     final:
214 
215       alias WaitHandle.handle handle; // Overload.
216 
217       override @property void handle(HANDLE h) { // setter
218          assert(0);
219       }
220 
221       override void close() {
222          WSACancelAsyncRequest(handle);
223          super.handle = INVALID_HANDLE;
224       }
225 
226 
227       private void _gotEvent() {
228          super.handle = INVALID_HANDLE;
229       }
230    }
231 
232 
233    private class GetHostAsyncResult, IAsyncResult {
234       this(HANDLE h, GetHostCallback callback) {
235          wh = new GetHostWaitHandle(h);
236          this.callback = callback;
237       }
238 
239 
240       @property WaitHandle asyncWaitHandle() { // getter
241          return wh;
242       }
243 
244 
245       @property bool completedSynchronously() { // getter
246          return false;
247       }
248 
249 
250       @property bool isCompleted() { // getter
251          return wh.handle != WaitHandle.INVALID_HANDLE;
252       }
253 
254 
255     private:
256       GetHostWaitHandle wh;
257       GetHostCallback callback;
258 
259 
260       void _gotEvent(LPARAM lparam) {
261          wh._gotEvent();
262 
263          callback(bla, HIWORD(lparam));
264       }
265    }
266    +/
267 
268 
269    private void _getHostErr() {
270       throw new DflException("Get host failure"); // Needs a better message.. ?
271    }
272 
273 
274    private class _InternetHost: DInternetHost {
275     private:
276       this(void* hostentBytes) {
277          super.validHostent(cast(hostent*)hostentBytes);
278          super.populate(cast(hostent*)hostentBytes);
279       }
280    }
281 
282 
283 
   // Completion callback for asyncGetHostByName() / asyncGetHostByAddr().
   // If -err- is nonzero, it is a winsock error code and -inetHost- is null.
   alias void delegate(DInternetHost inetHost, int err) GetHostCallback;
286 
287 
288 
289    class GetHost { // docmain
290 
291       void cancel() {
292          WSACancelAsyncRequest(h);
293          h = null;
294       }
295 
296 
297     private:
298       HANDLE h;
299       GetHostCallback callback;
300       DThrowable exception;
301       ubyte[/+MAXGETHOSTSTRUCT+/ 1024] hostentBytes;
302 
303 
304       void _gotEvent(LPARAM lparam) {
305          h = null;
306 
307          int err;
308          err = HIWORD(lparam);
309          if(err) {
310             callback(null, err);
311          } else {
312             callback(new _InternetHost(hostentBytes.ptr), 0);
313          }
314       }
315 
316 
317       this() {
318       }
319    }
320 
321 
322 
323    GetHost asyncGetHostByName(Dstring name, GetHostCallback callback) { // docmain
324       if(!hwNet) {
325          _init();
326       }
327 
328       HANDLE h;
329       GetHost result;
330 
331       result = new GetHost;
332       h = WSAAsyncGetHostByName(hwNet, WM_DFL_HOSTEVENT, unsafeStringz(name),
333                                 cast(char*)result.hostentBytes, result.hostentBytes.length);
334       if(!h) {
335          _getHostErr();
336       }
337 
338       result.h = h;
339       result.callback = callback;
340       allGetHosts[h] = result;
341 
342       return result;
343    }
344 
345 
346 
347    GetHost asyncGetHostByAddr(uint32_t addr, GetHostCallback callback) { // docmain
348       if(!hwNet) {
349          _init();
350       }
351 
352       HANDLE h;
353       GetHost result;
354 
355       result = new GetHost;
356       version(LittleEndian)
357       addr = bswap(addr);
358       h = WSAAsyncGetHostByAddr(hwNet, WM_DFL_HOSTEVENT, cast(char*)&addr, addr.sizeof,
359                                 AddressFamily.INET, cast(char*)result.hostentBytes, result.hostentBytes.length);
360       if(!h) {
361          _getHostErr();
362       }
363 
364       result.h = h;
365       result.callback = callback;
366       allGetHosts[h] = result;
367 
368       return result;
369    }
370 
371    /// ditto
372    // Shortcut.
373    GetHost asyncGetHostByAddr(Dstring addr, GetHostCallback callback) { // docmain
374       uint uiaddr;
375       uiaddr = DInternetAddress.parse(addr);
376       if(DInternetAddress.ADDR_NONE == uiaddr) {
377          _getHostErr();
378       }
379       return asyncGetHostByAddr(uiaddr, callback);
380    }
381 
382 
383 
384    class SocketQueue { // docmain
385 
386       this(DflSocket sock)
387       in {
388          assert(sock !is null);
389       }
390       body {
391          this.sock = sock;
392       }
393 
394 
395 
396       final @property DflSocket socket() { // getter
397          return sock;
398       }
399 
400 
401 
402       void reset() {
403          writebuf = null;
404          readbuf = null;
405       }
406 
407 
408       /+
409       // DMD 0.92 says error: function toString overrides but is not covariant with toString
410       override Dstring toString() {
411          return cast(Dstring)peek();
412       }
413       +/
414 
415 
416 
417       void[] peek() {
418          return readbuf[0 .. rpos];
419       }
420 
421       /// ditto
422       void[] peek(uint len) {
423          if(len >= rpos) {
424             return peek();
425          }
426 
427          return readbuf[0 .. len];
428       }
429 
430 
431 
432       void[] receive() {
433          ubyte[] result;
434 
435          result = readbuf[0 .. rpos];
436          readbuf = null;
437          rpos = 0;
438 
439          return result;
440       }
441 
442       /// ditto
443       void[] receive(uint len) {
444          if(len >= rpos) {
445             return receive();
446          }
447 
448          ubyte[] result;
449 
450          result = readbuf[0 .. len];
451          readbuf = readbuf[len .. readbuf.length];
452          rpos -= len;
453 
454          return result;
455       }
456 
457 
458 
459       void send(void[] buf) {
460          if(canwrite) {
461             assert(!writebuf.length);
462 
463             int st;
464             if(buf.length > 4096) {
465                st = 4096;
466             } else {
467                st = buf.length;
468             }
469 
470             st = sock.send(buf[0 .. st]);
471             if(st > 0) {
472                if(buf.length - st) {
473                   // dup so it can be appended to.
474                   writebuf = (cast(ubyte[])buf)[st .. buf.length].dup;
475                }
476             } else {
477                // dup so it can be appended to.
478                writebuf = (cast(ubyte[])buf).dup;
479             }
480 
481             //canwrite = false;
482          } else {
483             writebuf ~= cast(ubyte[])buf;
484          }
485       }
486 
487 
488 
489       // Number of bytes in send queue.
490       @property uint sendBytes() { // getter
491          return writebuf.length;
492       }
493 
494 
495 
496       // Number of bytes in recv queue.
497       @property uint receiveBytes() { // getter
498          return rpos;
499       }
500 
501 
502 
503       // Same signature as RegisterEventCallback for simplicity.
504       void event(DflSocket _sock, EventType type, int err)
505       in {
506          assert(_sock is sock);
507       }
508       body {
509          switch(type) {
510             case EventType.READ:
511                readEvent();
512                break;
513 
514             case EventType.WRITE:
515                writeEvent();
516                break;
517 
518             default:
519          }
520       }
521 
522 
523 
524       // Call on a read event so that incoming data may be buffered.
525       void readEvent() {
526          if(readbuf.length - rpos < 1024) {
527             readbuf.length = readbuf.length + 2048;
528          }
529 
530          int rd = sock.receive(readbuf[rpos .. readbuf.length]);
531          if(rd > 0) {
532             rpos += cast(uint)rd;
533          }
534       }
535 
536 
537 
538       // Call on a write event so that buffered outgoing data may be sent.
539       void writeEvent() {
540          if(writebuf.length) {
541             ubyte[] buf;
542 
543             if(writebuf.length > 4096) {
544                buf = writebuf[0 .. 4096];
545             } else {
546                buf = writebuf;
547             }
548 
549             int st = sock.send(buf);
550             if(st > 0) {
551                writebuf = writebuf[st .. writebuf.length];
552             }
553          } else {
554             //canwrite = true;
555          }
556       }
557 
558 
559       deprecated {
560          alias receiveBytes recvBytes;
561          alias receive recv;
562       }
563 
564 
565     private:
566       ubyte[] writebuf;
567       ubyte[] readbuf;
568       uint rpos;
569       DflSocket sock;
570       //bool canwrite = false;
571 
572 
573       @property bool canwrite() { // getter
574          return writebuf.length == 0;
575       }
576    }
577 
578 
579 private:
580 
   // Registration record stored in allEvents by registerEvent().
   struct EventInfo {
      DflSocket sock; // The registered socket, passed back to -callback-.
      RegisterEventCallback callback; // Invoked by netWndProc for each event.
      DThrowable exception; // Intended to record a Throwable escaping -callback- (see netWndProc).
   }
586 
587 
   // Private messages posted to hwNet: socket events (requested through
   // WSAAsyncSelect) and completed host lookups (WSAAsyncGetHostBy*).
   enum UINT WM_DFL_NETEVENT = WM_USER + 104;
   enum UINT WM_DFL_HOSTEVENT = WM_USER + 105;
   enum NETEVENT_CLASSNAME = "DFL_NetEvent"; // Window class registered by _init().

   EventInfo[socket_t] allEvents; // Registered sockets, keyed by native socket handle.
   GetHost[HANDLE] allGetHosts; // Pending lookups, keyed by Winsock async task handle.
   HWND hwNet; // Hidden window receiving the messages above; created lazily by _init().
595 
596 
597    extern(Windows) LRESULT netWndProc(HWND hwnd, UINT msg, WPARAM wparam, LPARAM lparam) nothrow {
598       switch(msg) {
599          case WM_DFL_NETEVENT:
600             if(cast(socket_t)wparam in allEvents) {
601                EventInfo ei = allEvents[cast(socket_t)wparam];
602                try {
603                   ei.callback(ei.sock, cast(EventType)LOWORD(lparam), HIWORD(lparam));
604                } catch (DThrowable e) {
605                   ei.exception = e;
606                }
607             }
608             break;
609 
610          case WM_DFL_HOSTEVENT:
611             if(cast(HANDLE)wparam in allGetHosts) {
612                GetHost gh;
613                gh = allGetHosts[cast(HANDLE)wparam];
614                assert(gh !is null);
615                //delete allGetHosts[cast(HANDLE)wparam];
616                allGetHosts.remove(cast(HANDLE)wparam);
617                try {
618                   gh._gotEvent(lparam);
619                } catch (DThrowable e) {
620                   gh.exception = e;
621                }
622             }
623             break;
624 
625          default:
626       }
627 
628       return 1;
629    }
630 
631 
632    void _init() {
633       WNDCLASSEXA wce;
634       wce.cbSize = wce.sizeof;
635       wce.lpszClassName = NETEVENT_CLASSNAME.ptr;
636       wce.lpfnWndProc = &netWndProc;
637       wce.hInstance = GetModuleHandleA(null);
638 
639       if(!RegisterClassExA(&wce)) {
640          debug(APP_PRINT)
641          cprintf("RegisterClassEx() failed for network event class.\n");
642 
643       init_err:
644          throw new DflException("Unable to initialize asynchronous socket library");
645       }
646 
647       hwNet = CreateWindowExA(0, NETEVENT_CLASSNAME.ptr, "", 0, 0, 0, 0, 0, HWND_MESSAGE, null, wce.hInstance, null);
648       if(!hwNet) {
649          // Guess it doesn't support HWND_MESSAGE, so just try null parent.
650 
651          hwNet = CreateWindowExA(0, NETEVENT_CLASSNAME.ptr, "", 0, 0, 0, 0, 0, null, null, wce.hInstance, null);
652          if(!hwNet) {
653             debug(APP_PRINT)
654             cprintf("CreateWindowEx() failed for network event window.\n");
655 
656             goto init_err;
657          }
658       }
659    }
660 
661 }
662