Skip to content

Commit 0b7e31d

Browse files
committed
reflect: permit DeepCopyOption
1 parent 1f2c918 commit 0b7e31d

File tree

2 files changed

+104
-29
lines changed

2 files changed

+104
-29
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ Package reflect implements the proposal [go.dev/issue/51520](https://go.dev/issu
4040
// memory representations of the source value but may result in unexpected
4141
// consequences in follow-up usage, the caller should clear these values
4242
// depending on their usage context.
43-
func DeepCopy[T any](src T) (dst T)
43+
//
44+
// To change these predefined behaviors, use provided DeepCopyOption.
45+
func DeepCopy[T any](src T, opts ...DeepCopyOption) (dst T)
4446
```
4547

4648
_Warning_: Not largely tested. Use it with care.

deepcopy.go

Lines changed: 101 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,52 @@ import (
1414
"reflect"
1515
"strings"
1616
"unsafe"
17-
18-
_ "unsafe" // for go:linkname
1917
)
2018

19+
// DeepCopyOption represents an option to customize deep copied results.
20+
type DeepCopyOption func(opt *copyConfig)
21+
22+
type copyConfig struct {
23+
disallowCopyUnexported bool
24+
disallowCopyCircular bool
25+
disallowCopyBidirectionalChan bool
26+
disallowCopyTypes []reflect.Type
27+
}
28+
29+
// DisallowCopyUnexported returns a DeepCopyOption that disables the behavior
30+
// of copying unexported fields.
31+
func DisallowCopyUnexported() DeepCopyOption {
32+
return func(opt *copyConfig) {
33+
opt.disallowCopyUnexported = true
34+
}
35+
}
36+
37+
// DisallowCopyCircular returns a DeepCopyOption that disables the behavior
38+
// of copying circular structures.
39+
func DisallowCopyCircular() DeepCopyOption {
40+
return func(opt *copyConfig) {
41+
opt.disallowCopyCircular = true
42+
}
43+
}
44+
45+
// DisallowCopyBidirectionalChan returns a DeepCopyOption that disables
46+
// the behavior of producing new channel when a bidirectional channel is copied.
47+
func DisallowCopyBidirectionalChan() DeepCopyOption {
48+
return func(opt *copyConfig) {
49+
opt.disallowCopyBidirectionalChan = true
50+
}
51+
}
52+
53+
// DisallowTypes returns a DeepCopyOption that disallows copying any types
54+
// that are in given values.
55+
func DisallowTypes(val ...any) DeepCopyOption {
56+
return func(opt *copyConfig) {
57+
for i := range val {
58+
opt.disallowCopyTypes = append(opt.disallowCopyTypes, reflect.TypeOf(val[i]))
59+
}
60+
}
61+
}
62+
2163
// DeepCopy copies src to dst recursively.
2264
//
2365
// Two values of identical type are deeply copied if one of the following
@@ -55,17 +97,33 @@ import (
5597
// memory representations of the source value but may result in unexpected
5698
// consequences in follow-up usage, the caller should clear these values
5799
// depending on their usage context.
58-
func DeepCopy[T any](src T) (dst T) {
100+
//
101+
// To change these predefined behaviors, use provided DeepCopyOption.
102+
func DeepCopy[T any](src T, opts ...DeepCopyOption) (dst T) {
59103
ptrs := map[uintptr]any{}
60-
ret := copyAny(src, ptrs)
104+
conf := &copyConfig{}
105+
for _, opt := range opts {
106+
opt(conf)
107+
}
108+
109+
ret := copyAny(src, ptrs, conf)
61110
if v, ok := ret.(T); ok {
62111
dst = v
63112
return
64113
}
65114
panic(fmt.Sprintf("reflect: internal error: copied value is not typed in %T, got %T", src, ret))
66115
}
67116

68-
func copyAny(src any, ptrs map[uintptr]any) (dst any) {
117+
func copyAny(src any, ptrs map[uintptr]any, copyConf *copyConfig) (dst any) {
118+
119+
if len(copyConf.disallowCopyTypes) != 0 {
120+
for i := range copyConf.disallowCopyTypes {
121+
if reflect.TypeOf(src) == copyConf.disallowCopyTypes[i] {
122+
panic(fmt.Sprintf("reflect: deep copying type %T is disallowed", src))
123+
}
124+
}
125+
}
126+
69127
v := reflect.ValueOf(src)
70128
if !v.IsValid() {
71129
return src
@@ -77,30 +135,30 @@ func copyAny(src any, ptrs map[uintptr]any) (dst any) {
77135
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
78136
reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64,
79137
reflect.Complex64, reflect.Complex128, reflect.Func:
80-
dst = copyPremitive(src, ptrs)
138+
dst = copyPremitive(src, ptrs, copyConf)
81139
case reflect.String:
82140
dst = strings.Clone(src.(string))
83141
case reflect.Slice:
84-
dst = copySlice(src, ptrs)
142+
dst = copySlice(src, ptrs, copyConf)
85143
case reflect.Array:
86-
dst = copyArray(src, ptrs)
144+
dst = copyArray(src, ptrs, copyConf)
87145
case reflect.Map:
88-
dst = copyMap(src, ptrs)
146+
dst = copyMap(src, ptrs, copyConf)
89147
case reflect.Ptr, reflect.UnsafePointer:
90-
dst = copyPointer(src, ptrs)
148+
dst = copyPointer(src, ptrs, copyConf)
91149
case reflect.Struct:
92-
dst = copyStruct(src, ptrs)
150+
dst = copyStruct(src, ptrs, copyConf)
93151
case reflect.Interface:
94-
dst = copyAny(src, ptrs)
152+
dst = copyAny(src, ptrs, copyConf)
95153
case reflect.Chan:
96-
dst = copyChan(src, ptrs)
154+
dst = copyChan(src, ptrs, copyConf)
97155
default:
98156
panic(fmt.Sprintf("reflect: internal error: unknown type %v", v.Kind()))
99157
}
100158
return
101159
}
102160

103-
func copyPremitive(src any, ptr map[uintptr]any) (dst any) {
161+
func copyPremitive(src any, ptr map[uintptr]any, copyConf *copyConfig) (dst any) {
104162
kind := reflect.ValueOf(src).Kind()
105163
switch kind {
106164
case reflect.Array, reflect.Chan, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.Struct, reflect.UnsafePointer:
@@ -110,7 +168,7 @@ func copyPremitive(src any, ptr map[uintptr]any) (dst any) {
110168
return
111169
}
112170

113-
func copySlice(x any, ptrs map[uintptr]any) any {
171+
func copySlice(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
114172
v := reflect.ValueOf(x)
115173
kind := v.Kind()
116174
if kind != reflect.Slice {
@@ -121,15 +179,15 @@ func copySlice(x any, ptrs map[uintptr]any) any {
121179
t := reflect.TypeOf(x)
122180
dc := reflect.MakeSlice(t, size, size)
123181
for i := 0; i < size; i++ {
124-
iv := reflect.ValueOf(copyAny(v.Index(i).Interface(), ptrs))
182+
iv := reflect.ValueOf(copyAny(v.Index(i).Interface(), ptrs, copyConf))
125183
if iv.IsValid() {
126184
dc.Index(i).Set(iv)
127185
}
128186
}
129187
return dc.Interface()
130188
}
131189

132-
func copyArray(x any, ptrs map[uintptr]any) any {
190+
func copyArray(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
133191
v := reflect.ValueOf(x)
134192
if v.Kind() != reflect.Array {
135193
panic(fmt.Errorf("reflect: internal error: must be an Array; got %v", v.Kind()))
@@ -138,13 +196,13 @@ func copyArray(x any, ptrs map[uintptr]any) any {
138196
size := t.Len()
139197
dc := reflect.New(reflect.ArrayOf(size, t.Elem())).Elem()
140198
for i := 0; i < size; i++ {
141-
item := copyAny(v.Index(i).Interface(), ptrs)
199+
item := copyAny(v.Index(i).Interface(), ptrs, copyConf)
142200
dc.Index(i).Set(reflect.ValueOf(item))
143201
}
144202
return dc.Interface()
145203
}
146204

147-
func copyMap(x any, ptrs map[uintptr]any) any {
205+
func copyMap(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
148206
v := reflect.ValueOf(x)
149207
if v.Kind() != reflect.Map {
150208
panic(fmt.Errorf("reflect: internal error: must be a Map; got %v", v.Kind()))
@@ -153,27 +211,30 @@ func copyMap(x any, ptrs map[uintptr]any) any {
153211
dc := reflect.MakeMapWithSize(t, v.Len())
154212
iter := v.MapRange()
155213
for iter.Next() {
156-
item := copyAny(iter.Value().Interface(), ptrs)
157-
k := copyAny(iter.Key().Interface(), ptrs)
214+
item := copyAny(iter.Value().Interface(), ptrs, copyConf)
215+
k := copyAny(iter.Key().Interface(), ptrs, copyConf)
158216
dc.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(item))
159217
}
160218
return dc.Interface()
161219
}
162220

163-
func copyPointer(x any, ptrs map[uintptr]any) any {
221+
func copyPointer(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
164222
v := reflect.ValueOf(x)
165223
if v.Kind() != reflect.Pointer {
166224
panic(fmt.Errorf("reflect: internal error: must be a Pointer or Ptr; got %v", v.Kind()))
167225
}
168226
addr := uintptr(v.UnsafePointer())
169227
if dc, ok := ptrs[addr]; ok {
228+
if copyConf.disallowCopyCircular {
229+
panic("reflect: deep copy dircular value is disallowed")
230+
}
170231
return dc
171232
}
172233
t := reflect.TypeOf(x)
173234
dc := reflect.New(t.Elem())
174235
ptrs[addr] = dc.Interface()
175236
if !v.IsNil() {
176-
item := copyAny(v.Elem().Interface(), ptrs)
237+
item := copyAny(v.Elem().Interface(), ptrs, copyConf)
177238
iv := reflect.ValueOf(item)
178239
if iv.IsValid() {
179240
dc.Elem().Set(reflect.ValueOf(item))
@@ -182,21 +243,30 @@ func copyPointer(x any, ptrs map[uintptr]any) any {
182243
return dc.Interface()
183244
}
184245

185-
func copyStruct(x any, ptrs map[uintptr]any) any {
246+
func copyStruct(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
186247
v := reflect.ValueOf(x)
187248
if v.Kind() != reflect.Struct {
188249
panic(fmt.Errorf("reflect: internal error: must be a Struct; got %v", v.Kind()))
189250
}
190251
t := reflect.TypeOf(x)
191252
dc := reflect.New(t)
192253
for i := 0; i < t.NumField(); i++ {
193-
item := copyAny(valueInterfaceUnsafe(v.Field(i)), ptrs)
194-
setField(dc.Elem().Field(i), reflect.ValueOf(item))
254+
if copyConf.disallowCopyUnexported {
255+
f := t.Field(i)
256+
if f.PkgPath != "" {
257+
continue
258+
}
259+
item := copyAny(v.Field(i).Interface(), ptrs, copyConf)
260+
dc.Elem().Field(i).Set(reflect.ValueOf(item))
261+
} else {
262+
item := copyAny(valueInterfaceUnsafe(v.Field(i)), ptrs, copyConf)
263+
setField(dc.Elem().Field(i), reflect.ValueOf(item))
264+
}
195265
}
196266
return dc.Elem().Interface()
197267
}
198268

199-
func copyChan(x any, ptrs map[uintptr]any) any {
269+
func copyChan(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
200270
v := reflect.ValueOf(x)
201271
if v.Kind() != reflect.Chan {
202272
panic(fmt.Errorf("reflect: internal error: must be a Chan; got %v", v.Kind()))
@@ -206,7 +276,10 @@ func copyChan(x any, ptrs map[uintptr]any) any {
206276
var dc any
207277
switch dir {
208278
case reflect.BothDir:
209-
dc = reflect.MakeChan(t, v.Cap()).Interface()
279+
if !copyConf.disallowCopyBidirectionalChan {
280+
dc = reflect.MakeChan(t, v.Cap()).Interface()
281+
}
282+
fallthrough
210283
case reflect.SendDir, reflect.RecvDir:
211284
dc = x
212285
}

0 commit comments

Comments
 (0)