# Source: buildscripts/gdb/mongo_lock.py (blob 04b76b315b17f134869282b5979bd1bbf5fc5fef)
from __future__ import print_function

import gdb
import gdb.printing
import re
import sys

# Python 3 has no separate 'long' type, yet this script calls long() to turn a gdb.Value into
# its numerical address (under Python 2, GDB only permits that conversion through long(), not
# int()). Alias 'long' to 'int' so the same calls keep working on Python 3.
if sys.version_info >= (3,):
    long = int


class Thread(object):
    """A thread node of the lock graph.

    Identity (equality, hash and key()) is based on thread_id only; lwpid is carried along
    purely for display in __str__.
    """

    def __init__(self, thread_id, lwpid):
        # thread_id must be an integer: it is rendered with the 'x' format specifier.
        self.thread_id = thread_id
        self.lwpid = lwpid

    def __eq__(self, other):
        if isinstance(other, Thread):
            return self.thread_id == other.thread_id
        return NotImplemented

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes the class unhashable on Python 3 and leaves the default
        # identity hash inconsistent with __eq__ on Python 2; hash on the same field that
        # equality compares.
        return hash(self.thread_id)

    def __str__(self):
        return "Thread 0x{:012x} (LWP {})".format(self.thread_id, self.lwpid)

    def key(self):
        """Return the string used as this node's key in the Graph."""
        return "Thread 0x{:012x}".format(self.thread_id)


class Lock(object):
    """A lock node of the lock graph.

    Identity (equality, hash and key()) is based on addr only; resource is a descriptive
    label that is shown in __str__.
    """

    def __init__(self, addr, resource):
        # addr must be an integer: it is rendered with the 'x' format specifier.
        self.addr = addr
        self.resource = resource

    def __eq__(self, other):
        if isinstance(other, Lock):
            return self.addr == other.addr
        return NotImplemented

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes the class unhashable on Python 3 and leaves the default
        # identity hash inconsistent with __eq__ on Python 2; hash on the same field that
        # equality compares.
        return hash(self.addr)

    def __str__(self):
        return "Lock 0x{:012x} ({})".format(self.addr, self.resource)

    def key(self):
        """Return the string used as this node's key in the Graph."""
        return "Lock 0x{:012x}".format(self.addr)


class Graph(object):
    """Directed graph of nodes identified by their key() string.

    self.nodes is a dict with the following structure:
      {'node_key': {'node': <node object>, 'next_nodes': [node_key_1, ...]}}
    where <node object> is whatever was passed to add_node()/add_edge() (any object with a
    key() method) and 'next_nodes' lists the keys this node has an outgoing edge to.
    Example graph (node objects shown by their str()):
      {
       'Lock 1': {'node': Lock 1, 'next_nodes': ['Thread 1']},
       'Lock 2': {'node': Lock 2, 'next_nodes': ['Thread 2']},
       'Thread 1': {'node': Thread 1, 'next_nodes': ['Lock 2']},
       'Thread 2': {'node': Thread 2, 'next_nodes': ['Lock 1']}
      }
    """

    def __init__(self):
        self.nodes = {}

    def is_empty(self):
        """Return True if the graph has no nodes."""
        return not bool(self.nodes)

    def add_node(self, node):
        """Add node under node.key() unless a node with that key already exists."""
        if not self.find_node(node):
            self.nodes[node.key()] = {'node': node, 'next_nodes': []}

    def find_node(self, node):
        """Return the graph entry stored under node.key(), or None if there is none."""
        return self.nodes.get(node.key())

    def find_from_node(self, from_node):
        """Return an entry that has an edge to the entry from_node, or None.

        from_node is a graph entry (a value of self.nodes), not a bare node object.
        """
        target_key = from_node['node'].key()
        for entry in self.nodes.values():
            if target_key in entry['next_nodes']:
                return entry
        return None

    def remove_nodes_without_edge(self):
        """Rebuild the graph without nodes that have no incoming and no outgoing edges."""
        # Collect every edge target once instead of rescanning the whole graph per node.
        keys_with_incoming = set()
        for entry in self.nodes.values():
            keys_with_incoming.update(entry['next_nodes'])
        self.nodes = {
            node_key: entry
            for node_key, entry in self.nodes.items()
            if entry['next_nodes'] or entry['node'].key() in keys_with_incoming
        }

    def add_edge(self, from_node, to_node):
        """Add the edge from_node -> to_node, creating either node if needed.

        Adding an edge that already exists is a no-op.
        """
        # add_node() leaves an existing entry untouched, so these calls are safe either way.
        self.add_node(from_node)
        self.add_node(to_node)

        to_key = to_node.key()
        next_nodes = self.nodes[from_node.key()]['next_nodes']
        if to_key not in next_nodes:
            next_nodes.append(to_key)

    def print(self):
        """Print every node followed by the keys of the nodes it points to."""
        for node_key in self.nodes:
            print("Node", self.nodes[node_key]['node'])
            for to in self.nodes[node_key]['next_nodes']:
                print(" ->", to)

    def to_graph(self, nodes=None, message=None):
        """Return the graph as text in Graphviz 'digraph' syntax.

        nodes: optional collection of node keys to highlight with 'color = red'.
        message: optional line emitted after the legend; the caller supplies it already
            formatted (for instance as a '#' comment line).
        """
        sb = []
        sb.append('# Legend:')
        sb.append('#    Thread 1 -> Lock 1 indicates Thread 1 is waiting on Lock 1')
        sb.append('#    Lock 2 -> Thread 2 indicates Lock 2 is held by Thread 2')
        if message is not None:
            sb.append(message)
        sb.append('digraph "mongod+lock-status" {')
        # Edges first, then one labelled declaration per node.
        for node_key in self.nodes:
            for next_node_key in self.nodes[node_key]['next_nodes']:
                sb.append('    "{}" -> "{}";'.format(node_key, next_node_key))
        for node_key in self.nodes:
            color = ""
            if nodes and node_key in nodes:
                color = "color = red"
            sb.append('    "{}" [label="{}" {}]'.format(
                node_key, self.nodes[node_key]['node'], color))
        sb.append("}")
        return "\n".join(sb)

    def depth_first_search(self, node_key, nodes_visited, nodes_in_cycle=None):
        """Search depth first from node_key; return the keys forming a cycle, or None.

        nodes_visited is the set of node keys already visited; it is updated in place.
        nodes_in_cycle is the list of node keys on the current search path (the potential
        cycle); a fresh list is used when it is omitted.
        """
        # A mutable default ([]) would be shared between calls and keep the path of a
        # previously found cycle, producing bogus cycles on later searches.
        if nodes_in_cycle is None:
            nodes_in_cycle = []
        nodes_visited.add(node_key)
        nodes_in_cycle.append(node_key)
        for node in self.nodes[node_key]['next_nodes']:
            if node in nodes_in_cycle:
                # The graph cycle starts at the index of node in nodes_in_cycle.
                return nodes_in_cycle[nodes_in_cycle.index(node):]
            # Only descend into nodes that have not been explored yet; an already visited
            # node that is not on the current path cannot lead back into it.
            if node not in nodes_visited:
                dfs_nodes = self.depth_first_search(node, nodes_visited, nodes_in_cycle)
                if dfs_nodes:
                    return dfs_nodes

        # This node_key is not part of the graph cycle.
        nodes_in_cycle.pop()
        return None

    def detect_cycle(self):
        """Return the list of node keys of a cycle if one is detected, else None."""
        nodes_visited = set()
        for node in self.nodes:
            if node not in nodes_visited:
                cycle_path = self.depth_first_search(node, nodes_visited)
                if cycle_path:
                    return cycle_path
        return None


def find_lwpid(thread_dict, search_thread_id):
    """Return the first key of thread_dict whose value equals search_thread_id, else None.

    thread_dict maps LWP id -> thread id, so this is the reverse lookup.
    """
    matching_lwpids = (
        lwpid for lwpid, thread_id in thread_dict.items() if thread_id == search_thread_id
    )
    return next(matching_lwpids, None)


def find_func_block(block):
    """Return block or its nearest enclosing superblock that has a function, else None.

    Walks the superblock chain starting at block and stops at the first block whose
    'function' attribute is truthy.
    """
    candidate = block
    while candidate and not candidate.function:
        candidate = candidate.superblock
    return candidate if candidate else None


def find_frame(function_name_pattern):
    """Return the newest frame of the selected thread whose function name matches, else None.

    Frames are walked from gdb.newest_frame() towards older frames. For each one, the name
    of the function owning the frame's block is tested with re.match(), so the pattern is
    anchored at the start of the name. A GDB error while stepping to the older frame is
    reported and ends the search.
    """
    current_frame = gdb.newest_frame()
    while current_frame:
        try:
            frame_block = current_frame.block()
        except RuntimeError as err:
            # A frame without block information is skipped; any other error is unexpected.
            if err.args[0] != "Cannot locate block for frame.":
                raise
            frame_block = None

        func_block = find_func_block(frame_block)
        if func_block:
            function_name = func_block.function.name
            if re.match(function_name_pattern, function_name):
                return current_frame

        try:
            current_frame = current_frame.older()
        except gdb.error as err:
            print("Ignoring GDB error '%s' in find_frame" % str(err))
            break
    return None


def find_mutex_holder(graph, thread_dict, show):
    """Report the std::mutex the selected thread is blocked on, and who owns it.

    Does nothing unless the selected thread has a frame in std::mutex::lock().

    graph: Graph to receive 'waiter -> mutex' and 'mutex -> holder' edges, or None.
    thread_dict: dict mapping LWP id -> thread id.
    show: when true, print a one-line description of the mutex, its holder and its waiter.
    """
    frame = find_frame(r'std::mutex::lock\(\)')
    if frame is None:
        return

    frame.select()

    # Waiting for mutex locking!
    mutex_this, _ = gdb.lookup_symbol("this", frame.block())
    if mutex_this is None:
        # gdb.lookup_symbol() yields None when the symbol is not found; without 'this'
        # there is no mutex object to inspect.
        return
    mutex_value = mutex_this.value(frame)
    # The mutex holder is a LWPID
    mutex_holder = int(mutex_value["_M_mutex"]["__data"]["__owner"])
    (_, mutex_waiter_lwpid, _) = gdb.selected_thread().ptid

    # thread_dict may lack an LWP (it is only as complete as its builder could make it).
    # A plain thread_dict[...] would raise KeyError, which the callers' 'except gdb.error'
    # does not catch, so fall back to tracking the thread by its LWP instead.
    if mutex_holder in thread_dict:
        mutex_holder_id = thread_dict[mutex_holder]
    else:
        print("Warning: Mutex at {} held by thread with LWP {} not found in thread_dict."
              " Using LWP to track thread.".format(mutex_value, mutex_holder))
        mutex_holder_id = mutex_holder

    if mutex_waiter_lwpid in thread_dict:
        mutex_waiter_id = thread_dict[mutex_waiter_lwpid]
    else:
        print("Warning: Mutex at {} waited on by thread with LWP {} not found in thread_dict."
              " Using LWP to track thread.".format(mutex_value, mutex_waiter_lwpid))
        mutex_waiter_id = mutex_waiter_lwpid

    if show:
        print("Mutex at {} held by thread 0x{:x} (LWP {}) "
              " waited on by thread 0x{:x} (LWP {})".format(mutex_value,
                                                            mutex_holder_id,
                                                            mutex_holder,
                                                            mutex_waiter_id,
                                                            mutex_waiter_lwpid))
    if graph:
        mutex_lock = Lock(long(mutex_value), "Mutex")
        graph.add_edge(Thread(mutex_waiter_id, mutex_waiter_lwpid), mutex_lock)
        graph.add_edge(mutex_lock, Thread(mutex_holder_id, mutex_holder))


def find_lock_manager_holders(graph, thread_dict, show):
    """Report the MongoDB lock the selected thread is in, and the threads it is granted to.

    Does nothing unless the selected thread has a frame inside a mongo::LockerImpl<...>
    member function. The lock head is looked up by evaluating an expression that uses
    'resId' in that frame, and its grantedList is walked.

    graph: Graph to receive 'waiter -> lock' and 'lock -> holder' edges, or None.
    thread_dict: dict mapping LWP id -> thread id.
    show: when true, print one line per granted lock request.
    """
    frame = find_frame(r'mongo::LockerImpl\<.*\>::')
    if not frame:
        return

    frame.select()

    (_, lwpid, _) = gdb.selected_thread().ptid

    locker_ptr_type = gdb.lookup_type("mongo::LockerImpl<false>").pointer()
    lock_head = gdb.parse_and_eval(
        "mongo::getGlobalLockManager()->_getBucket(resId)->findOrInsert(resId)")

    grantedList = lock_head.dereference()["grantedList"]
    lock_request_ptr = grantedList["_front"]

    # thread_dict may lack this thread's LWP. A plain thread_dict[lwpid] would raise
    # KeyError, which the callers' 'except gdb.error' does not catch, so fall back to
    # tracking the waiter by its LWP instead.
    if lwpid in thread_dict:
        lock_waiter_id = thread_dict[lwpid]
    else:
        lock_waiter_id = lwpid
        if lock_request_ptr:
            print("Warning: MongoDB Lock at {} waited on by thread with LWP {} not found in"
                  " thread_dict. Using LWP to track thread.".format(lock_head, lwpid))

    while lock_request_ptr:
        lock_request = lock_request_ptr.dereference()
        locker_ptr = lock_request["locker"]
        locker_ptr = locker_ptr.cast(locker_ptr_type)
        locker = locker_ptr.dereference()
        lock_thread_id = int(locker["_threadId"]["_M_thread"])
        # May be None when no LWP in thread_dict maps to the holder's thread id.
        lock_thread_lwpid = find_lwpid(thread_dict, lock_thread_id)
        if show:
            print("MongoDB Lock at {} ({}) held by thread id 0x{:x} (LWP {})".format(
                lock_head, lock_request["mode"], lock_thread_id, lock_thread_lwpid) +
                " waited on by thread 0x{:x} (LWP {})".format(lock_waiter_id, lwpid))
        if graph:
            mongo_lock = Lock(long(lock_head), "MongoDB lock")
            graph.add_edge(Thread(lock_waiter_id, lwpid), mongo_lock)
            graph.add_edge(mongo_lock, Thread(lock_thread_id, lock_thread_lwpid))
        lock_request_ptr = lock_request["next"]


def get_locks(graph, thread_dict, show=False):
    """Run the mutex and lock-manager searches on every valid thread of the inferior.

    Each valid thread is made the selected thread before it is inspected; the previously
    selected thread is not restored afterwards. A gdb.error raised while handling one
    thread is printed and the remaining threads are still processed.

    graph, thread_dict and show are passed through unchanged to find_mutex_holder() and
    find_lock_manager_holders().
    """
    for inferior_thread in gdb.selected_inferior().threads():
        try:
            if inferior_thread.is_valid():
                inferior_thread.switch()
                find_mutex_holder(graph, thread_dict, show)
                find_lock_manager_holders(graph, thread_dict, show)
        except gdb.error as err:
            print("Ignoring GDB error '%s' in get_locks" % str(err))


def get_threads_info(graph=None):
    """Return a dict mapping LWP id -> thread id for the valid threads of the inferior.

    Each valid thread is made the selected thread and get_thread_id() (not defined in this
    part of the file) is asked for its id; threads for which that yields a falsy value are
    reported and left out. A gdb.error for one thread is printed and the loop continues.

    graph is accepted but not used by this function.
    """
    lwpid_to_thread_id = {}
    for inferior_thread in gdb.selected_inferior().threads():
        try:
            if not inferior_thread.is_valid():
                continue
            inferior_thread.switch()
            # PTID is a tuple: Process ID (PID), Lightweight Process ID (LWPID), Thread ID (TID)
            (_pid, lwpid, _tid) = inferior_thread.ptid
            thread_num = inferior_thread.num
            thread_id = get_thread_id()
            if thread_id:
                lwpid_to_thread_id[lwpid] = thread_id
            else:
                print("Unable to retrieve thread_info for thread %d" % thread_num)
        except gdb.error as err:
            print("Ignoring GDB error '%s' in get_threads_info" % str(err))

    return lwpid_to_thread_id


class MongoDBShowLocks(gdb.Command):
    """Show MongoDB locks & pthread mutexes"""
    # NOTE(review): per the GDB Python API a gdb.Command's class docstring doubles as the
    # command's help text, so the docstring above is kept verbatim.

    def __init__(self):
        # Registration is delegated to register_mongo_command(), defined outside this file.
        register_mongo_command(self, "mongodb-show-locks", gdb.COMMAND_DATA)

    def invoke(self, arg, _from_tty):
        # The command takes no arguments; 'arg' is ignored.
        self.mongodb_show_locks()

    def mongodb_show_locks(self):
        """Print the locks and mutexes found on every thread; no graph is built."""
        try:
            lwpid_to_thread_id = get_threads_info()
            get_locks(graph=None, thread_dict=lwpid_to_thread_id, show=True)
        except gdb.error as err:
            print("Ignoring GDB error '%s' in mongodb_show_locks" % str(err))


# Instantiating the class runs __init__, which registers the command.
MongoDBShowLocks()


class MongoDBWaitsForGraph(gdb.Command):
    """Create MongoDB WaitsFor lock graph [graph_file]"""
    # NOTE(review): per the GDB Python API a gdb.Command's class docstring doubles as the
    # command's help text, so the docstring above is kept verbatim.

    def __init__(self):
        # Registration is delegated to register_mongo_command(), defined outside this file.
        register_mongo_command(self, "mongodb-waitsfor-graph", gdb.COMMAND_DATA)

    def invoke(self, arg, _from_tty):
        # 'arg' is the optional output file name given on the command line.
        self.mongodb_waitsfor_graph(arg)

    def mongodb_waitsfor_graph(self, file=None):
        """Build the waits-for graph of all threads and emit it in digraph syntax.

        file: name of the file to write the digraph to; when empty or None the digraph is
            printed instead. When a file is written, only the cycle summary is printed.

        Nothing is emitted when the graph has no edges. A gdb.error is printed and ignored.
        """
        graph = Graph()
        try:
            thread_dict = get_threads_info(graph=graph)
            get_locks(graph=graph, thread_dict=thread_dict, show=False)
            graph.remove_nodes_without_edge()
            if graph.is_empty():
                print("Not generating the digraph, since the lock graph is empty")
                return

            cycle_nodes = graph.detect_cycle()
            if cycle_nodes:
                cycle_summary = "Cycle detected in the graph nodes %s" % cycle_nodes
            else:
                cycle_summary = "No cycle detected in the graph"
            # The digraph carries the summary as a '#' comment line; the console gets the
            # bare summary, so build both from one string instead of splitting on "# ".
            digraph = graph.to_graph(nodes=cycle_nodes, message="# " + cycle_summary)

            if file:
                print("Saving digraph to %s" % file)
                with open(file, 'w') as f:
                    f.write(digraph)
                print(cycle_summary)
            else:
                print(digraph)

        except gdb.error as err:
            # Name the method that actually failed, like the other handlers in this file do.
            print("Ignoring GDB error '%s' in mongodb_waitsfor_graph" % str(err))


# Instantiating the class runs __init__, which registers the command.
MongoDBWaitsForGraph()

# Printed when this script is loaded, after both command classes above have been instantiated.
print("MongoDB Lock analysis commands loaded")