diff --git a/syscalls_windows.go b/syscalls_windows.go index c93ca86..38d4bcd 100644 --- a/syscalls_windows.go +++ b/syscalls_windows.go @@ -7,7 +7,9 @@ import ( "net" "os" "sync" + "sync/atomic" "syscall" + "time" "unsafe" "golang.org/x/sys/windows" @@ -286,11 +288,47 @@ func openTap(config Config) (ifce *Interface, err error) { return nil, errIfceNameNotFound } +// https://github.com/WireGuard/wireguard-go/blob/master/tun/tun_windows.go +const ( + rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) + spinloopRateThreshold = 800000000 / 8 // 800mbps + spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s +) + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +//go:linkname nanotime runtime.nanotime +func nanotime() int64 + +type rateJuggler struct { + current uint64 + nextByteCount uint64 + nextStartTime int64 + changing int32 +} + +func (rate *rateJuggler) update(packetLen uint64) { + now := nanotime() + total := atomic.AddUint64(&rate.nextByteCount, packetLen) + period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + if period >= rateMeasurementGranularity { + if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + return + } + atomic.StoreInt64(&rate.nextStartTime, now) + atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) + atomic.StoreUint64(&rate.nextByteCount, 0) + atomic.StoreInt32(&rate.changing, 0) + } +} + type wintunRWC struct { ad wintun.Adapter s wintun.Session readwait windows.Handle readbuf []byte + rate rateJuggler isclosed bool } @@ -301,6 +339,7 @@ func (w *wintunRWC) Close() error { } func (w *wintunRWC) Write(b []byte) (int, error) { + w.rate.update(uint64(len(b))) packet, err := w.s.AllocateSendPacket(len(b)) switch err { case nil: @@ -336,20 +375,30 @@ RETRY: if w.isclosed { return 0, errors.New("wintun is closed") } - packet, err := w.s.ReceivePacket() - switch err { - case nil: - n += copy(b, packet) - if len(packet) > len(b) { - w.readbuf = make([]byte, len(packet)-len(b)) - copy(w.readbuf, packet[len(b):]) + start := nanotime() + shouldSpin := atomic.LoadUint64(&w.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&w.rate.nextStartTime)) <= rateMeasurementGranularity*2 + for { + packet, err := w.s.ReceivePacket() + switch err { + case nil: + packetSize := len(packet) + n += copy(b, packet) + if len(packet) > len(b) { + w.readbuf = make([]byte, len(packet)-len(b)) + copy(w.readbuf, packet[len(b):]) + } + w.s.ReleaseReceivePacket(packet) + w.rate.update(uint64(packetSize)) + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(w.readwait, windows.INFINITE) + goto RETRY + } + procyield(1) + continue } - w.s.ReleaseReceivePacket(packet) - case windows.ERROR_NO_MORE_ITEMS: - windows.WaitForSingleObject(w.readwait, windows.INFINITE) - goto RETRY + return n, err } - return n, err } // openDev find and open an interface.