diff --git a/mock/mock.go b/mock/mock.go
index 49328337b..70e9485e1 100644
--- a/mock/mock.go
+++ b/mock/mock.go
@@ -18,8 +18,8 @@ import (
"github.com/stretchr/testify/assert"
)
-// regex for GCCGO functions
-var gccgoRE = regexp.MustCompile(`\.pN\d+_`)
+// regex for GCCGO functions and "-fm"
+var gccgoRE = regexp.MustCompile(`(\.pN\d+_|-fm)`)
// TestingT is an interface wrapper around *testing.T
type TestingT interface {
@@ -32,6 +32,11 @@ type TestingT interface {
Call
*/
+type CallSetup struct {
+ Arguments Arguments
+ Returns Returns
+}
+
// Call represents a method call and is used for setting expectations,
// as well as recording activity.
type Call struct {
@@ -102,6 +107,40 @@ func (c *Call) unlock() {
c.Parent.mutex.Unlock()
}
+// Returns the function name from the function path
+func getMethodNameFromPath(functionPath string) string {
+ if functionPath == "" {
+ panic("method name could not be empty")
+ }
+
+ // Next four lines are required to use GCCGO function naming conventions.
+ // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
+ // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
+ // With GCCGO we need to remove interface information starting from pN
.
+ // Also, GCCGO uses .-fm instead of . for function names in interface pointers paths.
+ if gccgoRE.MatchString(functionPath) {
+ functionPath = gccgoRE.Split(functionPath, -1)[0]
+ }
+
+ parts := strings.Split(functionPath, ".")
+ functionName := parts[len(parts)-1]
+
+ return functionName
+}
+
+// Returns the function name from the function method interface
+func getFuncInterfaceMethodName(method interface{}) string {
+ methodValue := reflect.ValueOf(method)
+ methodType := methodValue.Type()
+ if methodType.Kind() != reflect.Func {
+ panic("method must be a function")
+ }
+
+ functionPath := runtime.FuncForPC(methodValue.Pointer()).Name()
+
+ return getMethodNameFromPath(functionPath)
+}
+
// Return specifies the return arguments for the expectation.
//
// Mock.On("DoSomething").Return(errors.New("failed"))
@@ -208,6 +247,19 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call {
return c.Parent.On(methodName, arguments...)
}
+// OnFunc chains a new expectation description onto the mocked interface. This
+// allows syntax like.
+//
+// Mock.
+// OnFunc(mockedService.MyMethod, 1).Return(nil).
+// OnFunc(mockedService.MyOtherMethod, 'a', 'b', 'c').Return(errors.New("Some Error"))
+//
+//go:noinline
+func (c *Call) OnFunc(method interface{}, arguments ...interface{}) *Call {
+ methodName := getFuncInterfaceMethodName(method)
+ return c.Parent.On(methodName, arguments...)
+}
+
// Unset removes a mock handler from being called.
//
// test.On("func", mock.Anything).Unset()
@@ -371,6 +423,16 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
return c
}
+// OnFunc starts a description of an expectation of the specified method
+// being called.
+//
+// Mock.OnFunc(mockedService.MyMethod, arg1, arg2)
+func (m *Mock) OnFunc(method interface{}, arguments ...interface{}) *Call {
+ methodName := getFuncInterfaceMethodName(method)
+
+ return m.On(methodName, arguments...)
+}
+
// /*
// Recording and responding to activity
// */
@@ -469,18 +531,41 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
panic("Couldn't get the caller information")
}
functionPath := runtime.FuncForPC(pc).Name()
- // Next four lines are required to use GCCGO function naming conventions.
- // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
- // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
- // With GCCGO we need to remove interface information starting from pN.
- if gccgoRE.MatchString(functionPath) {
- functionPath = gccgoRE.Split(functionPath, -1)[0]
- }
- parts := strings.Split(functionPath, ".")
- functionName := parts[len(parts)-1]
+ functionName := getMethodNameFromPath(functionPath)
+
return m.MethodCalled(functionName, arguments...)
}
+func (m *Mock) assignIfNotNil(returnValue interface{}, target interface{}) {
+ if returnValue != nil {
+ reflect.ValueOf(target).Elem().Set(reflect.ValueOf(returnValue))
+ }
+}
+
+// MockCall is a helper function to mock a method call with the given setup.
+// It is useful when you want to mock a method call in a easier way in few lines using syntax like.
+//
+// var r0 *SomeType0
+// var r1 *SomeType1
+// var rN error
+// i.MockCall(CallSetup{Arguments: Arguments{args...}, Returns: Returns{&r0, &r1, &rN}})
+//
+// return r0, r1, rN
+func (m *Mock) MockCall(setup CallSetup) {
+ pc, _, _, ok := runtime.Caller(1)
+ if !ok {
+ panic("Couldn't get the caller information")
+ }
+ functionPath := runtime.FuncForPC(pc).Name()
+ functionName := getMethodNameFromPath(functionPath)
+
+ returns := m.MethodCalled(functionName, setup.Arguments...)
+
+ for ind, target := range setup.Returns {
+ m.assignIfNotNil(returns[ind], target)
+ }
+}
+
// MethodCalled tells the mock object that the given method has been called, and gets
// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
// by appropriate .On .Return() calls)
@@ -772,6 +857,7 @@ func (m *Mock) calls() []Call {
// Arguments holds an array of method arguments or return values.
type Arguments []interface{}
+type Returns []interface{}
const (
// Anything is used in Diff and Assert when the argument being tested
diff --git a/mock/mock_test.go b/mock/mock_test.go
index 5aab204b9..fdfe82359 100644
--- a/mock/mock_test.go
+++ b/mock/mock_test.go
@@ -20,6 +20,7 @@ import (
// ExampleInterface represents an example interface.
type ExampleInterface interface {
TheExampleMethod(a, b, c int) (int, error)
+ TheExampleMethodFast(a, b, c int) (int, error)
}
// TestExampleImplementation is a test implementation of ExampleInterface
@@ -32,6 +33,14 @@ func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) {
return args.Int(0), errors.New("Whoops")
}
+func (i *TestExampleImplementation) TheExampleMethodFast(a, b, c int) (int, error) {
+ var r0 int
+ var r1 error
+ i.MockCall(CallSetup{Arguments: Arguments{a, b, c}, Returns: Returns{&r0, &r1}})
+
+ return r0, r1
+}
+
type options struct {
num int
str string
@@ -69,6 +78,15 @@ func (i *TestExampleImplementation) TheExampleMethod3(et *ExampleType) error {
return args.Error(0)
}
+func (i *TestExampleImplementation) TheExampleMethod3Fast(et *ExampleType) (string, int, error) {
+ var r0 string
+ var r1 int
+ var r2 error
+ i.MockCall(CallSetup{Arguments: Arguments{et}, Returns: Returns{&r0, &r1, &r2}})
+
+ return r0, r1, r2
+}
+
func (i *TestExampleImplementation) TheExampleMethod4(v ExampleInterface) error {
args := i.Called(v)
return args.Error(0)
@@ -169,6 +187,16 @@ func Test_Mock_On(t *testing.T) {
assert.Equal(t, "TheExampleMethod", c.Method)
}
+func Test_Mock_OnFunc(t *testing.T) {
+
+ // make a test impl object
+ var mockedService = new(TestExampleImplementation)
+
+ c := mockedService.OnFunc(mockedService.TheExampleMethod)
+ assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls)
+ assert.Equal(t, "TheExampleMethod", c.Method)
+}
+
func Test_Mock_Chained_On(t *testing.T) {
// make a test impl object
var mockedService = new(TestExampleImplementation)
@@ -200,6 +228,68 @@ func Test_Mock_Chained_On(t *testing.T) {
assert.Equal(t, expectedCalls, mockedService.ExpectedCalls)
}
+func Test_Mock_Chained_OnFunc(t *testing.T) {
+ // make a test impl object
+ var mockedService = new(TestExampleImplementation)
+
+ // determine our current line number so we can assert the expected calls callerInfo properly
+ _, filename, line, _ := runtime.Caller(0)
+ mockedService.
+ OnFunc(mockedService.TheExampleMethod, 1, 2, 3).
+ Return(0).
+ OnFunc(mockedService.TheExampleMethod3, AnythingOfType("*mock.ExampleType")).
+ Return(nil)
+
+ expectedCalls := []*Call{
+ {
+ Parent: &mockedService.Mock,
+ Method: "TheExampleMethod",
+ Arguments: []interface{}{1, 2, 3},
+ ReturnArguments: []interface{}{0},
+ callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+2)},
+ },
+ {
+ Parent: &mockedService.Mock,
+ Method: "TheExampleMethod3",
+ Arguments: []interface{}{AnythingOfType("*mock.ExampleType")},
+ ReturnArguments: []interface{}{nil},
+ callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+4)},
+ },
+ }
+ assert.Equal(t, expectedCalls, mockedService.ExpectedCalls)
+}
+
+func Test_Mock_Chained_OnFuncFast(t *testing.T) {
+ // make a test impl object
+ var mockedService = new(TestExampleImplementation)
+
+ // determine our current line number so we can assert the expected calls callerInfo properly
+ _, filename, line, _ := runtime.Caller(0)
+ mockedService.
+ OnFunc(mockedService.TheExampleMethodFast, 1, 2, 3).
+ Return(0).
+ OnFunc(mockedService.TheExampleMethod3Fast, AnythingOfType("*mock.ExampleType")).
+ Return("Quick", 1, nil)
+
+ expectedCalls := []*Call{
+ {
+ Parent: &mockedService.Mock,
+ Method: "TheExampleMethodFast",
+ Arguments: []interface{}{1, 2, 3},
+ ReturnArguments: []interface{}{0},
+ callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+2)},
+ },
+ {
+ Parent: &mockedService.Mock,
+ Method: "TheExampleMethod3Fast",
+ Arguments: []interface{}{AnythingOfType("*mock.ExampleType")},
+ ReturnArguments: []interface{}{"Quick", 1, nil},
+ callerInfo: []string{fmt.Sprintf("%s:%d", filename, line+4)},
+ },
+ }
+ assert.Equal(t, expectedCalls, mockedService.ExpectedCalls)
+}
+
func Test_Mock_On_WithArgs(t *testing.T) {
// make a test impl object
@@ -212,6 +302,18 @@ func Test_Mock_On_WithArgs(t *testing.T) {
assert.Equal(t, Arguments{1, 2, 3, 4}, c.Arguments)
}
+func Test_Mock_OnFunc_WithArgs(t *testing.T) {
+
+ // make a test impl object
+ var mockedService = new(TestExampleImplementation)
+
+ c := mockedService.OnFunc(mockedService.TheExampleMethod, 1, 2, 3, 4)
+
+ assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls)
+ assert.Equal(t, "TheExampleMethod", c.Method)
+ assert.Equal(t, Arguments{1, 2, 3, 4}, c.Arguments)
+}
+
func Test_Mock_On_WithFuncArg(t *testing.T) {
// make a test impl object
@@ -233,6 +335,27 @@ func Test_Mock_On_WithFuncArg(t *testing.T) {
})
}
+func Test_Mock_OnFunc_WithFuncArg(t *testing.T) {
+
+ // make a test impl object
+ var mockedService = new(TestExampleImplementation)
+
+ c := mockedService.
+ OnFunc(mockedService.TheExampleMethodFunc, AnythingOfType("func(string) error")).
+ Return(nil)
+
+ assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls)
+ assert.Equal(t, "TheExampleMethodFunc", c.Method)
+ assert.Equal(t, 1, len(c.Arguments))
+ assert.Equal(t, AnythingOfType("func(string) error"), c.Arguments[0])
+
+ fn := func(string) error { return nil }
+
+ assert.NotPanics(t, func() {
+ mockedService.TheExampleMethodFunc(fn)
+ })
+}
+
func Test_Mock_On_WithIntArgMatcher(t *testing.T) {
var mockedService TestExampleImplementation
@@ -256,6 +379,29 @@ func Test_Mock_On_WithIntArgMatcher(t *testing.T) {
})
}
+func Test_Mock_OnFunc_WithIntArgMatcher(t *testing.T) {
+ var mockedService TestExampleImplementation
+
+ mockedService.OnFunc(mockedService.TheExampleMethod,
+ MatchedBy(func(a int) bool {
+ return a == 1
+ }), MatchedBy(func(b int) bool {
+ return b == 2
+ }), MatchedBy(func(c int) bool {
+ return c == 3
+ })).Return(0, nil)
+
+ assert.Panics(t, func() {
+ mockedService.TheExampleMethod(1, 2, 4)
+ })
+ assert.Panics(t, func() {
+ mockedService.TheExampleMethod(2, 2, 3)
+ })
+ assert.NotPanics(t, func() {
+ mockedService.TheExampleMethod(1, 2, 3)
+ })
+}
+
func Test_Mock_On_WithArgMatcherThatPanics(t *testing.T) {
var mockedService TestExampleImplementation
@@ -282,6 +428,32 @@ func Test_Mock_On_WithArgMatcherThatPanics(t *testing.T) {
})
}
+func Test_Mock_OnFunc_WithArgMatcherThatPanics(t *testing.T) {
+ var mockedService TestExampleImplementation
+
+ mockedService.OnFunc(mockedService.TheExampleMethod2, MatchedBy(func(_ interface{}) bool {
+ panic("try to lock mockedService")
+ })).Return()
+
+ defer func() {
+ assertedExpectations := make(chan struct{})
+ go func() {
+ tt := new(testing.T)
+ mockedService.AssertExpectations(tt)
+ close(assertedExpectations)
+ }()
+ select {
+ case <-assertedExpectations:
+ case <-time.After(time.Second):
+ t.Fatal("AssertExpectations() deadlocked, did the panic leave mockedService locked?")
+ }
+ }()
+
+ assert.Panics(t, func() {
+ mockedService.TheExampleMethod2(false)
+ })
+}
+
func TestMock_WithTest(t *testing.T) {
var (
mockedService TestExampleImplementation
@@ -310,6 +482,34 @@ func TestMock_WithTest(t *testing.T) {
assert.Equal(t, 1, mockedTest.failNowCount)
}
+func TestMock_OnFunc_WithTest(t *testing.T) {
+ var (
+ mockedService TestExampleImplementation
+ mockedTest MockTestingT
+ )
+
+ mockedService.Test(&mockedTest)
+ mockedService.OnFunc(mockedService.TheExampleMethod, 1, 2, 3).Return(0, nil)
+
+ // Test that on an expected call, the test was not failed
+
+ mockedService.TheExampleMethod(1, 2, 3)
+
+ // Assert that Errorf and FailNow were not called
+ assert.Equal(t, 0, mockedTest.errorfCount)
+ assert.Equal(t, 0, mockedTest.failNowCount)
+
+ // Test that on unexpected call, the mocked test was called to fail the test
+
+ assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() {
+ mockedService.TheExampleMethod(1, 1, 1)
+ })
+
+ // Assert that Errorf and FailNow were called once
+ assert.Equal(t, 1, mockedTest.errorfCount)
+ assert.Equal(t, 1, mockedTest.failNowCount)
+}
+
func Test_Mock_On_WithPtrArgMatcher(t *testing.T) {
var mockedService TestExampleImplementation