diff --git a/gomock/call.go b/gomock/call.go index 9f3ae9c8..dc2a479d 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -118,7 +118,8 @@ func (c *Call) Times(n int) *Call { } // SetArg declares an action that will set the nth argument's value, -// indirected through a pointer. +// indirected through a pointer. Or, in the case of a slice, SetArg +// will copy value's elements into the nth argument. func (c *Call) SetArg(n int, value interface{}) *Call { if c.setArgs == nil { c.setArgs = make(map[int]reflect.Value) @@ -142,8 +143,10 @@ func (c *Call) SetArg(n int, value interface{}) *Call { } case reflect.Interface: // nothing to do + case reflect.Slice: + // nothing to do default: - c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface type %v [%s]", + c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v", n, at, c.origin) } c.setArgs[n] = reflect.ValueOf(value) @@ -243,7 +246,12 @@ func (c *Call) call(args []interface{}) (rets []interface{}, action func()) { action = func() { c.doFunc.Call(doArgs) } } for n, v := range c.setArgs { - reflect.ValueOf(args[n]).Elem().Set(v) + switch reflect.TypeOf(args[n]).Kind() { + case reflect.Slice: + setSlice(args[n], v) + default: + reflect.ValueOf(args[n]).Elem().Set(v) + } } rets = c.rets @@ -265,3 +273,10 @@ func InOrder(calls ...*Call) { calls[i].After(calls[i-1]) } } + +func setSlice(arg interface{}, v reflect.Value) { + va := reflect.ValueOf(arg) + for i := 0; i < v.Len(); i++ { + va.Index(i).Set(v.Index(i)) + } +} diff --git a/gomock/call_test.go b/gomock/call_test.go index 3ae7263c..71a20788 100644 --- a/gomock/call_test.go +++ b/gomock/call_test.go @@ -1,6 +1,9 @@ package gomock -import "testing" +import ( + "reflect" + "testing" +) type mockTestReporter struct { errorCalls int @@ -45,3 +48,32 @@ func TestCall_After(t *testing.T) { } }) } + +func TestCall_SetArg(t *testing.T) { + t.Run("SetArgSlice", func(t *testing.T) { + c := &Call{ + methodType: reflect.TypeOf(func([]byte) {}), + } + c.SetArg(0, []byte{1, 2, 3}) + + in := []byte{4, 5, 6} + c.call([]interface{}{in}) + + if in[0] != 1 || in[1] != 2 || in[2] != 3 { + t.Error("Expected SetArg() to modify input slice argument") + } + }) + + t.Run("SetArgPointer", func(t *testing.T) { + c := &Call{ + methodType: reflect.TypeOf(func(*int) {}), + } + c.SetArg(0, 42) + + in := 43 + c.call([]interface{}{&in}) + if in != 42 { + t.Error("Expected SetArg() to modify value pointed to by argument") + } + }) +}