Skip to content

Commit 3a0dfef

Browse files
bradfitzzx2c4
authored andcommitted
all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <[email protected]> Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent d1d0842 commit 3a0dfef

20 files changed

+156
-246
lines changed

conn/bind_windows.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ type afWinRingBind struct {
7474
type WinRingBind struct {
7575
v4, v6 afWinRingBind
7676
mu sync.RWMutex
77-
isOpen uint32
77+
isOpen atomic.Uint32 // 0, 1, or 2
7878
}
7979

8080
func NewDefaultBind() Bind { return NewWinRingBind() }
@@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
212212
}
213213

214214
func (bind *WinRingBind) closeAndZero() {
215-
atomic.StoreUint32(&bind.isOpen, 0)
215+
bind.isOpen.Store(0)
216216
bind.v4.CloseAndZero()
217217
bind.v6.CloseAndZero()
218218
}
@@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
276276
bind.closeAndZero()
277277
}
278278
}()
279-
if atomic.LoadUint32(&bind.isOpen) != 0 {
279+
if bind.isOpen.Load() != 0 {
280280
return nil, 0, ErrBindAlreadyOpen
281281
}
282282
var sa windows.Sockaddr
@@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
299299
return nil, 0, err
300300
}
301301
}
302-
atomic.StoreUint32(&bind.isOpen, 1)
302+
bind.isOpen.Store(1)
303303
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
304304
}
305305

306306
func (bind *WinRingBind) Close() error {
307307
bind.mu.RLock()
308-
if atomic.LoadUint32(&bind.isOpen) != 1 {
308+
if bind.isOpen.Load() != 1 {
309309
bind.mu.RUnlock()
310310
return nil
311311
}
312-
atomic.StoreUint32(&bind.isOpen, 2)
312+
bind.isOpen.Store(2)
313313
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
314314
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
315315
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
345345
//go:linkname procyield runtime.procyield
346346
func procyield(cycles uint32)
347347

348-
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
349-
if atomic.LoadUint32(isOpen) != 1 {
348+
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
349+
if isOpen.Load() != 1 {
350350
return 0, nil, net.ErrClosed
351351
}
352352
bind.rx.mu.Lock()
@@ -359,7 +359,7 @@ retry:
359359
count = 0
360360
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
361361
if tries > 0 {
362-
if atomic.LoadUint32(isOpen) != 1 {
362+
if isOpen.Load() != 1 {
363363
return 0, nil, net.ErrClosed
364364
}
365365
procyield(1)
@@ -378,7 +378,7 @@ retry:
378378
if err != nil {
379379
return 0, nil, err
380380
}
381-
if atomic.LoadUint32(isOpen) != 1 {
381+
if isOpen.Load() != 1 {
382382
return 0, nil, net.ErrClosed
383383
}
384384
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
@@ -395,7 +395,7 @@ retry:
395395
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
396396
// attacker bandwidth, just like the rest of the receive path.
397397
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
398-
if atomic.LoadUint32(isOpen) != 1 {
398+
if isOpen.Load() != 1 {
399399
return 0, nil, net.ErrClosed
400400
}
401401
goto retry
@@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
421421
return bind.v6.Receive(buf, &bind.isOpen)
422422
}
423423

424-
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
425-
if atomic.LoadUint32(isOpen) != 1 {
424+
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
425+
if isOpen.Load() != 1 {
426426
return net.ErrClosed
427427
}
428428
if len(buf) > bytesPerPacket {
@@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
444444
if err != nil {
445445
return err
446446
}
447-
if atomic.LoadUint32(isOpen) != 1 {
447+
if isOpen.Load() != 1 {
448448
return net.ErrClosed
449449
}
450450
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
538538
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
539539
bind.mu.RLock()
540540
defer bind.mu.RUnlock()
541-
if atomic.LoadUint32(&bind.isOpen) != 1 {
541+
if bind.isOpen.Load() != 1 {
542542
return net.ErrClosed
543543
}
544544
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
552552
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
553553
bind.mu.RLock()
554554
defer bind.mu.RUnlock()
555-
if atomic.LoadUint32(&bind.isOpen) != 1 {
555+
if bind.isOpen.Load() != 1 {
556556
return net.ErrClosed
557557
}
558558
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)

device/alignment_test.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,6 @@ func checkAlignment(t *testing.T, name string, offset uintptr) {
1818
}
1919
}
2020

21-
// TestPeerAlignment checks that atomically-accessed fields are
22-
// aligned to 64-bit boundaries, as required by the atomic package.
23-
//
24-
// Unfortunately, violating this rule on 32-bit platforms results in a
25-
// hard segfault at runtime.
26-
func TestPeerAlignment(t *testing.T) {
27-
var p Peer
28-
29-
typ := reflect.TypeOf(&p).Elem()
30-
t.Logf("Peer type size: %d, with fields:", typ.Size())
31-
for i := 0; i < typ.NumField(); i++ {
32-
field := typ.Field(i)
33-
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
34-
field.Name,
35-
field.Offset,
36-
field.Type.Size(),
37-
field.Type.Align(),
38-
)
39-
}
40-
41-
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
42-
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
43-
}
44-
4521
// TestDeviceAlignment checks that atomically-accessed fields are
4622
// aligned to 64-bit boundaries, as required by the atomic package.
4723
//

device/device.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type Device struct {
3030
// will become the actual state; Up can fail.
3131
// The device can also change state multiple times between time of check and time of use.
3232
// Unsynchronized uses of state must therefore be advisory/best-effort only.
33-
state uint32 // actually a deviceState, but typed uint32 for convenience
33+
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
3434
// stopping blocks until all inputs to Device have been closed.
3535
stopping sync.WaitGroup
3636
// mu protects state changes.
@@ -60,7 +60,7 @@ type Device struct {
6060

6161
// Keep this 8-byte aligned
6262
rate struct {
63-
underLoadUntil int64
63+
underLoadUntil atomic.Int64
6464
limiter ratelimiter.Ratelimiter
6565
}
6666

@@ -82,7 +82,7 @@ type Device struct {
8282

8383
tun struct {
8484
device tun.Device
85-
mtu int32
85+
mtu atomic.Int32
8686
}
8787

8888
ipcMutex sync.RWMutex
@@ -94,10 +94,9 @@ type Device struct {
9494
// There are three states: down, up, closed.
9595
// Transitions:
9696
//
97-
// down -----+
98-
// ↑↓ ↓
99-
// up -> closed
100-
//
97+
// down -----+
98+
// ↑↓ ↓
99+
// up -> closed
101100
type deviceState uint32
102101

103102
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -110,7 +109,7 @@ const (
110109
// deviceState returns device.state.state as a deviceState
111110
// See those docs for how to interpret this value.
112111
func (device *Device) deviceState() deviceState {
113-
return deviceState(atomic.LoadUint32(&device.state.state))
112+
return deviceState(device.state.state.Load())
114113
}
115114

116115
// isClosed reports whether the device is closed (or is closing).
@@ -149,14 +148,14 @@ func (device *Device) changeState(want deviceState) (err error) {
149148
case old:
150149
return nil
151150
case deviceStateUp:
152-
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
151+
device.state.state.Store(uint32(deviceStateUp))
153152
err = device.upLocked()
154153
if err == nil {
155154
break
156155
}
157156
fallthrough // up failed; bring the device all the way back down
158157
case deviceStateDown:
159-
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
158+
device.state.state.Store(uint32(deviceStateDown))
160159
errDown := device.downLocked()
161160
if err == nil {
162161
err = errDown
@@ -182,7 +181,7 @@ func (device *Device) upLocked() error {
182181
device.peers.RLock()
183182
for _, peer := range device.peers.keyMap {
184183
peer.Start()
185-
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
184+
if peer.persistentKeepaliveInterval.Load() > 0 {
186185
peer.SendKeepalive()
187186
}
188187
}
@@ -219,11 +218,11 @@ func (device *Device) IsUnderLoad() bool {
219218
now := time.Now()
220219
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
221220
if underLoad {
222-
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
221+
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
223222
return true
224223
}
225224
// check if recently under load
226-
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
225+
return device.rate.underLoadUntil.Load() > now.UnixNano()
227226
}
228227

229228
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -283,7 +282,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
283282

284283
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
285284
device := new(Device)
286-
device.state.state = uint32(deviceStateDown)
285+
device.state.state.Store(uint32(deviceStateDown))
287286
device.closed = make(chan struct{})
288287
device.log = logger
289288
device.net.bind = bind
@@ -293,7 +292,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
293292
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
294293
mtu = DefaultMTU
295294
}
296-
device.tun.mtu = int32(mtu)
295+
device.tun.mtu.Store(int32(mtu))
297296
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
298297
device.rate.limiter.Init()
299298
device.indexTable.Init()
@@ -359,7 +358,7 @@ func (device *Device) Close() {
359358
if device.isClosed() {
360359
return
361360
}
362-
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
361+
device.state.state.Store(uint32(deviceStateClosed))
363362
device.log.Verbosef("Device closing")
364363

365364
device.tun.device.Close()

device/device_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
333333

334334
// Measure how long it takes to receive b.N packets,
335335
// starting when we receive the first packet.
336-
var recv uint64
336+
var recv atomic.Uint64
337337
var elapsed time.Duration
338338
var wg sync.WaitGroup
339339
wg.Add(1)
@@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
342342
var start time.Time
343343
for {
344344
<-pair[0].tun.Inbound
345-
new := atomic.AddUint64(&recv, 1)
345+
new := recv.Add(1)
346346
if new == 1 {
347347
start = time.Now()
348348
}
@@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
358358
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
359359
pingc := pair[1].tun.Outbound
360360
var sent uint64
361-
for atomic.LoadUint64(&recv) != uint64(b.N) {
361+
for recv.Load() != uint64(b.N) {
362362
sent++
363363
pingc <- ping
364364
}

device/keypair.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"sync"
1111
"sync/atomic"
1212
"time"
13-
"unsafe"
1413

1514
"golang.zx2c4.com/wireguard/replay"
1615
)
@@ -23,7 +22,7 @@ import (
2322
*/
2423

2524
type Keypair struct {
26-
sendNonce uint64 // accessed atomically
25+
sendNonce atomic.Uint64
2726
send cipher.AEAD
2827
receive cipher.AEAD
2928
replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
3736
sync.RWMutex
3837
current *Keypair
3938
previous *Keypair
40-
next *Keypair
41-
}
42-
43-
func (kp *Keypairs) storeNext(next *Keypair) {
44-
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
45-
}
46-
47-
func (kp *Keypairs) loadNext() *Keypair {
48-
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
39+
next atomic.Pointer[Keypair]
4940
}
5041

5142
func (kp *Keypairs) Current() *Keypair {

device/misc.go

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)