forked from Leo-Guo/monkey
-
Notifications
You must be signed in to change notification settings - Fork 4
/
monkey.go
133 lines (109 loc) · 3.07 KB
/
monkey.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package monkey // import "bou.ke/monkey"
import (
"fmt"
"reflect"
"sync"
"unsafe"
)
// patch records a monkey patch that has been applied to a function and holds
// what is needed to undo it.
type patch struct {
// originalBytes is the byte slice returned by replaceFunction when the patch
// was applied; unpatch writes it back to the target address via copyToLocation.
originalBytes []byte
// replacement points at the replacement function value. NOTE(review):
// presumably retained so the replacement closure stays reachable (and is not
// garbage-collected) while the patched code jumps to it — confirm.
replacement *reflect.Value
}
// Package-level patch registry. patches maps the address of each patched
// function to the record needed to undo that patch; every access to it in this
// file happens while holding lock.
var (
	lock    sync.Mutex
	patches = make(map[uintptr]patch)
)
// value is overlaid on a reflect.Value so its data pointer can be read.
// NOTE(review): this assumes reflect.Value begins with a one-word type field
// followed by the data pointer. That is an unexported runtime layout and may
// change between Go releases.
type value struct {
	_   uintptr
	ptr unsafe.Pointer
}

// getPtr returns the data pointer held inside v by reinterpreting v's header
// as a value. For a func value this is the same pointer a func variable holds.
func getPtr(v reflect.Value) unsafe.Pointer {
	header := (*value)(unsafe.Pointer(&v))
	return header.ptr
}
// PatchGuard is returned by Patch and PatchInstanceMethod. It remembers the
// target and replacement of a single patch so that patch can be removed and
// re-applied later.
type PatchGuard struct {
	target, replacement reflect.Value
}

// Unpatch removes the guarded patch from its target. Whether the target was
// actually patched is not reported.
func (pg *PatchGuard) Unpatch() {
	_ = unpatchValue(pg.target) // the was-it-patched result is deliberately discarded
}

// Restore re-applies the guarded patch, redirecting target to replacement again.
func (pg *PatchGuard) Restore() {
	target, replacement := pg.target, pg.replacement
	patchValue(target, replacement)
}
// Patch replaces the function target with replacement and returns a guard that
// can undo or re-apply the patch. Both arguments must be functions of exactly
// the same type; otherwise Patch panics.
func Patch(target, replacement interface{}) *PatchGuard {
	guard := &PatchGuard{
		target:      reflect.ValueOf(target),
		replacement: reflect.ValueOf(replacement),
	}
	patchValue(guard.target, guard.replacement)
	return guard
}
// PatchInstanceMethod replaces the method methodName of type target with
// replacement and returns a guard that can undo or re-apply the patch.
// The patch is applied to the method's Func value, which takes the receiver as
// its first argument, so replacement must also take the receiver (of type
// target) as its first argument. It panics if target has no method named
// methodName.
func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *PatchGuard {
	method, found := target.MethodByName(methodName)
	if !found {
		panic(fmt.Sprintf("unknown method %s", methodName))
	}
	guard := &PatchGuard{target: method.Func, replacement: reflect.ValueOf(replacement)}
	patchValue(guard.target, guard.replacement)
	return guard
}
// patchValue redirects calls to the function target to replacement and records
// the patch in patches so it can be undone later. If target is already
// patched, the existing patch is removed before the new one is applied.
//
// It panics if either value is not a func, or if their types differ.
func patchValue(target, replacement reflect.Value) {
	lock.Lock()
	defer lock.Unlock()

	if target.Kind() != reflect.Func {
		panic("target has to be a Func")
	}
	if replacement.Kind() != reflect.Func {
		panic("replacement has to be a Func")
	}
	if target.Type() != replacement.Type() {
		panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type()))
	}

	addr := target.Pointer()
	// Undo any earlier patch first, so that the bytes returned by replaceFunction
	// (stored as originalBytes) come from the unpatched function. This assumes
	// replaceFunction returns the bytes it overwrites.
	if existing, ok := patches[addr]; ok {
		unpatch(addr, existing)
	}
	original := replaceFunction(addr, uintptr(getPtr(replacement)))
	// Store a pointer to replacement with the record. Presumably this keeps the
	// replacement closure reachable while the patched code jumps to it.
	patches[addr] = patch{originalBytes: original, replacement: &replacement}
}
// Unpatch removes any monkey patch applied to the function target. It reports
// whether target was patched in the first place.
func Unpatch(target interface{}) bool {
	v := reflect.ValueOf(target)
	return unpatchValue(v)
}
// UnpatchInstanceMethod removes the patch on the method methodName of type
// target. It reports whether the method was patched in the first place, and
// panics if target has no method named methodName.
func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {
	method, found := target.MethodByName(methodName)
	if found {
		return unpatchValue(method.Func)
	}
	panic(fmt.Sprintf("unknown method %s", methodName))
}
// UnpatchAll removes every monkey patch currently applied and empties the
// patch registry.
func UnpatchAll() {
	lock.Lock()
	defer lock.Unlock()
	// Deleting the current key while ranging over a map is permitted in Go.
	for addr, applied := range patches {
		unpatch(addr, applied)
		delete(patches, addr)
	}
}
// unpatchValue removes the monkey patch from the function target, writing its
// saved original bytes back, and drops the record from patches. It reports
// whether target was patched in the first place.
func unpatchValue(target reflect.Value) bool {
	lock.Lock()
	defer lock.Unlock()

	addr := target.Pointer()
	applied, ok := patches[addr]
	if !ok {
		return false
	}
	unpatch(addr, applied)
	delete(patches, addr)
	return true
}
// unpatch restores the code at address target by writing back the bytes saved
// in p when the patch was applied. It does not touch patches; every caller in
// this file holds lock and manages the registry itself.
func unpatch(target uintptr, p patch) {
	saved := p.originalBytes
	copyToLocation(target, saved)
}