diff --git a/mock/mock.go b/mock/mock.go index 49328337b..70e9485e1 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -18,8 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) -// regex for GCCGO functions -var gccgoRE = regexp.MustCompile(`\.pN\d+_`) +// regex for GCCGO functions and "-fm" +var gccgoRE = regexp.MustCompile(`(\.pN\d+_|-fm)`) // TestingT is an interface wrapper around *testing.T type TestingT interface { @@ -32,6 +32,11 @@ type TestingT interface { Call */ +type CallSetup struct { + Arguments Arguments + Returns Returns +} + // Call represents a method call and is used for setting expectations, // as well as recording activity. type Call struct { @@ -102,6 +107,40 @@ func (c *Call) unlock() { c.Parent.mutex.Unlock() } +// Returns the function name from the function path +func getMethodNameFromPath(functionPath string) string { + if functionPath == "" { + panic("method name could not be empty") + } + + // Next four lines are required to use GCCGO function naming conventions. + // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock + // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree + // With GCCGO we need to remove interface information starting from pN
. + // Also, GCCGO uses .-fm instead of . for function names in interface pointers paths. + if gccgoRE.MatchString(functionPath) { + functionPath = gccgoRE.Split(functionPath, -1)[0] + } + + parts := strings.Split(functionPath, ".") + functionName := parts[len(parts)-1] + + return functionName +} + +// Returns the function name from the function method interface +func getFuncInterfaceMethodName(method interface{}) string { + methodValue := reflect.ValueOf(method) + methodType := methodValue.Type() + if methodType.Kind() != reflect.Func { + panic("method must be a function") + } + + functionPath := runtime.FuncForPC(methodValue.Pointer()).Name() + + return getMethodNameFromPath(functionPath) +} + // Return specifies the return arguments for the expectation. // // Mock.On("DoSomething").Return(errors.New("failed")) @@ -208,6 +247,19 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } +// OnFunc chains a new expectation description onto the mocked interface. This +// allows syntax like. +// +// Mock. +// OnFunc(mockedService.MyMethod, 1).Return(nil). +// OnFunc(mockedService.MyOtherMethod, 'a', 'b', 'c').Return(errors.New("Some Error")) +// +//go:noinline +func (c *Call) OnFunc(method interface{}, arguments ...interface{}) *Call { + methodName := getFuncInterfaceMethodName(method) + return c.Parent.On(methodName, arguments...) +} + // Unset removes a mock handler from being called. // // test.On("func", mock.Anything).Unset() @@ -371,6 +423,16 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call { return c } +// OnFunc starts a description of an expectation of the specified method +// being called. +// +// Mock.OnFunc(mockedService.MyMethod, arg1, arg2) +func (m *Mock) OnFunc(method interface{}, arguments ...interface{}) *Call { + methodName := getFuncInterfaceMethodName(method) + + return m.On(methodName, arguments...) +} + // /* // Recording and responding to activity // */ @@ -469,18 +531,41 @@ func (m *Mock) Called(arguments ...interface{}) Arguments { panic("Couldn't get the caller information") } functionPath := runtime.FuncForPC(pc).Name() - // Next four lines are required to use GCCGO function naming conventions. - // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock - // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree - // With GCCGO we need to remove interface information starting from pN
. - if gccgoRE.MatchString(functionPath) { - functionPath = gccgoRE.Split(functionPath, -1)[0] - } - parts := strings.Split(functionPath, ".") - functionName := parts[len(parts)-1] + functionName := getMethodNameFromPath(functionPath) + return m.MethodCalled(functionName, arguments...) } +func (m *Mock) assignIfNotNil(returnValue interface{}, target interface{}) { + if returnValue != nil { + reflect.ValueOf(target).Elem().Set(reflect.ValueOf(returnValue)) + } +} + +// MockCall is a helper function to mock a method call with the given setup. +// It is useful when you want to mock a method call in a easier way in few lines using syntax like. +// +// var r0 *SomeType0 +// var r1 *SomeType1 +// var rN error +// i.MockCall(CallSetup{Arguments: Arguments{args...}, Returns: Returns{&r0, &r1, &rN}}) +// +// return r0, r1, rN +func (m *Mock) MockCall(setup CallSetup) { + pc, _, _, ok := runtime.Caller(1) + if !ok { + panic("Couldn't get the caller information") + } + functionPath := runtime.FuncForPC(pc).Name() + functionName := getMethodNameFromPath(functionPath) + + returns := m.MethodCalled(functionName, setup.Arguments...) + + for ind, target := range setup.Returns { + m.assignIfNotNil(returns[ind], target) + } +} + // MethodCalled tells the mock object that the given method has been called, and gets // an array of arguments to return. Panics if the call is unexpected (i.e. not preceded // by appropriate .On .Return() calls) @@ -772,6 +857,7 @@ func (m *Mock) calls() []Call { // Arguments holds an array of method arguments or return values. type Arguments []interface{} +type Returns []interface{} const ( // Anything is used in Diff and Assert when the argument being tested diff --git a/mock/mock_test.go b/mock/mock_test.go index 5aab204b9..fdfe82359 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -20,6 +20,7 @@ import ( // ExampleInterface represents an example interface. type ExampleInterface interface { TheExampleMethod(a, b, c int) (int, error) + TheExampleMethodFast(a, b, c int) (int, error) } // TestExampleImplementation is a test implementation of ExampleInterface @@ -32,6 +33,14 @@ func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) { return args.Int(0), errors.New("Whoops") } +func (i *TestExampleImplementation) TheExampleMethodFast(a, b, c int) (int, error) { + var r0 int + var r1 error + i.MockCall(CallSetup{Arguments: Arguments{a, b, c}, Returns: Returns{&r0, &r1}}) + + return r0, r1 +} + type options struct { num int str string @@ -69,6 +78,15 @@ func (i *TestExampleImplementation) TheExampleMethod3(et *ExampleType) error { return args.Error(0) } +func (i *TestExampleImplementation) TheExampleMethod3Fast(et *ExampleType) (string, int, error) { + var r0 string + var r1 int + var r2 error + i.MockCall(CallSetup{Arguments: Arguments{et}, Returns: Returns{&r0, &r1, &r2}}) + + return r0, r1, r2 +} + func (i *TestExampleImplementation) TheExampleMethod4(v ExampleInterface) error { args := i.Called(v) return args.Error(0) @@ -169,6 +187,16 @@ func Test_Mock_On(t *testing.T) { assert.Equal(t, "TheExampleMethod", c.Method) } +func Test_Mock_OnFunc(t *testing.T) { + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService.OnFunc(mockedService.TheExampleMethod) + assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + assert.Equal(t, "TheExampleMethod", c.Method) +} + func Test_Mock_Chained_On(t *testing.T) { // make a test impl object var mockedService = new(TestExampleImplementation) @@ -200,6 +228,68 @@ func Test_Mock_Chained_On(t *testing.T) { assert.Equal(t, expectedCalls, mockedService.ExpectedCalls) } +func Test_Mock_Chained_OnFunc(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + // determine our current line number so we can assert the expected calls callerInfo properly + _, filename, line, _ := runtime.Caller(0) + mockedService. + OnFunc(mockedService.TheExampleMethod, 1, 2, 3). + Return(0). + OnFunc(mockedService.TheExampleMethod3, AnythingOfType("*mock.ExampleType")). + Return(nil) + + expectedCalls := []*Call{ + { + Parent: &mockedService.Mock, + Method: "TheExampleMethod", + Arguments: []interface{}{1, 2, 3}, + ReturnArguments: []interface{}{0}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+2)}, + }, + { + Parent: &mockedService.Mock, + Method: "TheExampleMethod3", + Arguments: []interface{}{AnythingOfType("*mock.ExampleType")}, + ReturnArguments: []interface{}{nil}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+4)}, + }, + } + assert.Equal(t, expectedCalls, mockedService.ExpectedCalls) +} + +func Test_Mock_Chained_OnFuncFast(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + // determine our current line number so we can assert the expected calls callerInfo properly + _, filename, line, _ := runtime.Caller(0) + mockedService. + OnFunc(mockedService.TheExampleMethodFast, 1, 2, 3). + Return(0). + OnFunc(mockedService.TheExampleMethod3Fast, AnythingOfType("*mock.ExampleType")). + Return("Quick", 1, nil) + + expectedCalls := []*Call{ + { + Parent: &mockedService.Mock, + Method: "TheExampleMethodFast", + Arguments: []interface{}{1, 2, 3}, + ReturnArguments: []interface{}{0}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+2)}, + }, + { + Parent: &mockedService.Mock, + Method: "TheExampleMethod3Fast", + Arguments: []interface{}{AnythingOfType("*mock.ExampleType")}, + ReturnArguments: []interface{}{"Quick", 1, nil}, + callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+4)}, + }, + } + assert.Equal(t, expectedCalls, mockedService.ExpectedCalls) +} + func Test_Mock_On_WithArgs(t *testing.T) { // make a test impl object @@ -212,6 +302,18 @@ func Test_Mock_On_WithArgs(t *testing.T) { assert.Equal(t, Arguments{1, 2, 3, 4}, c.Arguments) } +func Test_Mock_OnFunc_WithArgs(t *testing.T) { + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService.OnFunc(mockedService.TheExampleMethod, 1, 2, 3, 4) + + assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + assert.Equal(t, "TheExampleMethod", c.Method) + assert.Equal(t, Arguments{1, 2, 3, 4}, c.Arguments) +} + func Test_Mock_On_WithFuncArg(t *testing.T) { // make a test impl object @@ -233,6 +335,27 @@ func Test_Mock_On_WithFuncArg(t *testing.T) { }) } +func Test_Mock_OnFunc_WithFuncArg(t *testing.T) { + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService. + OnFunc(mockedService.TheExampleMethodFunc, AnythingOfType("func(string) error")). + Return(nil) + + assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + assert.Equal(t, "TheExampleMethodFunc", c.Method) + assert.Equal(t, 1, len(c.Arguments)) + assert.Equal(t, AnythingOfType("func(string) error"), c.Arguments[0]) + + fn := func(string) error { return nil } + + assert.NotPanics(t, func() { + mockedService.TheExampleMethodFunc(fn) + }) +} + func Test_Mock_On_WithIntArgMatcher(t *testing.T) { var mockedService TestExampleImplementation @@ -256,6 +379,29 @@ func Test_Mock_On_WithIntArgMatcher(t *testing.T) { }) } +func Test_Mock_OnFunc_WithIntArgMatcher(t *testing.T) { + var mockedService TestExampleImplementation + + mockedService.OnFunc(mockedService.TheExampleMethod, + MatchedBy(func(a int) bool { + return a == 1 + }), MatchedBy(func(b int) bool { + return b == 2 + }), MatchedBy(func(c int) bool { + return c == 3 + })).Return(0, nil) + + assert.Panics(t, func() { + mockedService.TheExampleMethod(1, 2, 4) + }) + assert.Panics(t, func() { + mockedService.TheExampleMethod(2, 2, 3) + }) + assert.NotPanics(t, func() { + mockedService.TheExampleMethod(1, 2, 3) + }) +} + func Test_Mock_On_WithArgMatcherThatPanics(t *testing.T) { var mockedService TestExampleImplementation @@ -282,6 +428,32 @@ func Test_Mock_On_WithArgMatcherThatPanics(t *testing.T) { }) } +func Test_Mock_OnFunc_WithArgMatcherThatPanics(t *testing.T) { + var mockedService TestExampleImplementation + + mockedService.OnFunc(mockedService.TheExampleMethod2, MatchedBy(func(_ interface{}) bool { + panic("try to lock mockedService") + })).Return() + + defer func() { + assertedExpectations := make(chan struct{}) + go func() { + tt := new(testing.T) + mockedService.AssertExpectations(tt) + close(assertedExpectations) + }() + select { + case <-assertedExpectations: + case <-time.After(time.Second): + t.Fatal("AssertExpectations() deadlocked, did the panic leave mockedService locked?") + } + }() + + assert.Panics(t, func() { + mockedService.TheExampleMethod2(false) + }) +} + func TestMock_WithTest(t *testing.T) { var ( mockedService TestExampleImplementation @@ -310,6 +482,34 @@ func TestMock_WithTest(t *testing.T) { assert.Equal(t, 1, mockedTest.failNowCount) } +func TestMock_OnFunc_WithTest(t *testing.T) { + var ( + mockedService TestExampleImplementation + mockedTest MockTestingT + ) + + mockedService.Test(&mockedTest) + mockedService.OnFunc(mockedService.TheExampleMethod, 1, 2, 3).Return(0, nil) + + // Test that on an expected call, the test was not failed + + mockedService.TheExampleMethod(1, 2, 3) + + // Assert that Errorf and FailNow were not called + assert.Equal(t, 0, mockedTest.errorfCount) + assert.Equal(t, 0, mockedTest.failNowCount) + + // Test that on unexpected call, the mocked test was called to fail the test + + assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() { + mockedService.TheExampleMethod(1, 1, 1) + }) + + // Assert that Errorf and FailNow were called once + assert.Equal(t, 1, mockedTest.errorfCount) + assert.Equal(t, 1, mockedTest.failNowCount) +} + func Test_Mock_On_WithPtrArgMatcher(t *testing.T) { var mockedService TestExampleImplementation