summaryrefslogtreecommitdiff
path: root/pyreverse/diagrams.py
blob: 7f281d44727bca6ddf45169c62f38d89a1a938bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright (c) 2004 LOGILAB S.A. (Paris, FRANCE).
# http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
"""diagram objects
"""

__revision__ = "$Id: diagrams.py,v 1.6 2006-03-14 09:56:08 syt Exp $"

from pyreverse.utils import is_interface
from logilab import astng

def set_counter(value):
    """reset the shared figure id counter

    the next Figure instantiated after this call gets `value + 1` as its
    `fig_id`
    """
    setattr(Figure, '_UID_COUNT', value)
    
class Figure:
    """base class of every diagram element

    each instance receives an integer id in `fig_id`, one more than the
    previously issued id (the counter can be reset with set_counter)
    """
    # last id handed out; always read/written on Figure itself, so it is
    # shared by every subclass
    _UID_COUNT = 0

    def __init__(self):
        next_id = Figure._UID_COUNT + 1
        Figure._UID_COUNT = next_id
        self.fig_id = next_id
        
class Relationship(Figure):
    """a relationship drawn from one diagram object to another

    `type` holds the kind of link (a string such as 'specialization' or
    'association') and `name` an optional label for it
    """
    def __init__(self, from_object, to_object, r_type, name=None):
        Figure.__init__(self)
        self.from_object, self.to_object = from_object, to_object
        self.type, self.name = r_type, name
        
    
class DiagramEntity(Figure):
    """a diagram object, i.e. a title attached to an astng node
    """
    def __init__(self, title='No name', node=None):
        Figure.__init__(self)
        # the astng node this entity represents (may be None)
        self.node = node
        # label shown for this entity
        self.title = title

class ClassDiagram(Figure):
    """a class diagram object

    holds one DiagramEntity per astng node added through `add_object`, and
    the relationships between those entities grouped by relationship type
    """
    TYPE = 'class'
    def __init__(self, title='No name'):
        Figure.__init__(self)
        self.title = title
        # DiagramEntity instances, in insertion order
        self.objects = []
        # relationship type (e.g. 'specialization') -> list of Relationship
        self.relationships = {}
        # astng node -> DiagramEntity, used for lookups by node
        self._nodes = {}

    def add_relationship(self, from_object, to_object, r_type, name=None):
        """create a relationship of type `r_type` from `from_object` to
        `to_object`, optionally labelled `name`, and register it
        """
        rel = Relationship(from_object, to_object, r_type, name)
        self.relationships.setdefault(r_type, []).append(rel)

    def get_relationship(self, from_object, r_type):
        """return the first relationship of type `r_type` whose source is
        `from_object` (compared by identity)

        raise KeyError if there is no such relationship
        """
        for rel in self.relationships.get(r_type, ()):
            if rel.from_object is from_object:
                return rel
        raise KeyError(r_type)

    def add_object(self, title, node):
        """create a diagram entity titled `title` for `node` and add it to
        the diagram; a node may only be added once
        """
        # `in` instead of dict.has_key, which no longer exists in Python 3
        assert node not in self._nodes
        ent = DiagramEntity(title, node)
        self._nodes[node] = ent
        self.objects.append(ent)

    def nodes(self):
        """return the list of underlying nodes
        """
        # explicit list: on Python 3 keys() would return a view, not a list
        return list(self._nodes.keys())

    def has_node(self, node):
        """return true if the given node is included in the diagram
        """
        return node in self._nodes

    def object_from_node(self, node):
        """return the diagram entity mapped to `node`

        raise KeyError if the node is not in the diagram
        """
        return self._nodes[node]

    def classes(self):
        """return all entities whose node is an astng class"""
        return [o for o in self.objects if isinstance(o.node, astng.Class)]

    def classe(self, name):
        """return a class entity by its node name, raise KeyError if not
        found
        """
        for klass in self.classes():
            if klass.node.name == name:
                return klass
        raise KeyError(name)

    def _add_link(self, from_object, to_node, r_type, name=None):
        """add a `r_type` relationship from `from_object` to the entity
        mapped to `to_node`; do nothing if `to_node` is not in the diagram
        """
        try:
            to_object = self.object_from_node(to_node)
        except KeyError:
            # the target is outside the diagram: there is nothing to link to
            return
        self.add_relationship(from_object, to_object, r_type, name)

    def extract_relationships(self):
        """extract relationships between nodes in the diagram

        for every class entity, set its `shape` ('interface' or 'class') and
        add 'specialization' links to its base classes, 'implements' links
        to the nodes in its `implements` attribute and 'association' links
        (labelled with the attribute name) to the nodes in its
        `instance_attrs_type` mapping, in each case only when the target
        node is part of the diagram
        """
        for obj in self.classes():
            node = obj.node
            # shape
            if is_interface(node):
                obj.shape = 'interface'
            else:
                obj.shape = 'class'
            # inheritance link
            for par_node in node.baseobjects:
                self._add_link(obj, par_node, 'specialization')
            # implements link
            for impl_node in node.implements:
                self._add_link(obj, impl_node, 'implements')
            # associations link
            for name, value in node.instance_attrs_type.items():
                self._add_link(obj, value, 'association', name)
        
class PackageDiagram(ClassDiagram):
    """a package diagram object

    a class diagram that may also contain module nodes; besides the class
    relationships it records 'ownership' links from classes to modules and
    'depends' links between modules
    """
    TYPE = 'package'

    def modules(self):
        """return all entities whose node is an astng module"""
        return [o for o in self.objects if isinstance(o.node, astng.Module)]

    def module(self, name):
        """return a module entity by its node name, raise KeyError if not
        found
        """
        for mod in self.modules():
            if mod.node.name == name:
                return mod
        raise KeyError(name)

    def extract_relationships(self):
        """extract relationships between nodes in the diagram

        on top of the class relationships, add an 'ownership' link from each
        class entity to the entity mapped to `node.root()` (presumably the
        class's module), set the `shape` of every module entity to
        'package', and add a 'depends' link from each module to every
        module of the diagram whose name appears in its `depends` attribute
        """
        ClassDiagram.extract_relationships(self)
        for obj in self.classes():
            # ownership
            try:
                mod = self.object_from_node(obj.node.root())
            except KeyError:
                # the owning node is not part of the diagram
                continue
            self.add_relationship(obj, mod, 'ownership')
        for obj in self.modules():
            obj.shape = 'package'
            # dependencies: only link to modules present in the diagram
            for dep_name in obj.node.depends:
                try:
                    dep = self.module(dep_name)
                except KeyError:
                    continue
                self.add_relationship(obj, dep, 'depends')