# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

"""Diagram objects."""

from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import astroid
from astroid import nodes, util

from pylint.checkers.utils import decorated_with_property
from pylint.pyreverse.utils import FilterMixIn


class Figure:
    """Base class for counter handling."""

    def __init__(self) -> None:
        # Identifier of the figure; starts empty and is assigned from outside
        # this module.
        self.fig_id: str = ""


class Relationship(Figure):
    """A relationship from an object in the diagram to another."""

    def __init__(
        self,
        from_object: DiagramEntity,
        to_object: DiagramEntity,
        relation_type: str,
        name: str | None = None,
    ):
        """Store both ends of the relationship, its type and its optional name."""
        super().__init__()
        self.from_object = from_object
        self.to_object = to_object
        self.type = relation_type
        self.name = name


class DiagramEntity(Figure):
    """A diagram object, i.e. a label associated to an astroid node."""

    default_shape = ""

    def __init__(
        self, title: str = "No name", node: nodes.NodeNG | None = None
    ) -> None:
        """Create the entity; a falsy ``node`` is replaced by a bare ``NodeNG``."""
        super().__init__()
        self.title = title
        self.node: nodes.NodeNG = node if node else nodes.NodeNG()
        self.shape = self.default_shape


class PackageEntity(DiagramEntity):
    """A diagram object representing a package."""

    default_shape = "package"


class ClassEntity(DiagramEntity):
    """A diagram object representing a class."""

    default_shape = "class"

    def __init__(self, title: str, node: nodes.ClassDef) -> None:
        """Create the entity with empty attribute and method lists."""
        super().__init__(title=title, node=node)
        # Both lists are filled by ``ClassDiagram.extract_relationships``.
        self.attrs: list[str] = []
        self.methods: list[nodes.FunctionDef] = []


class ClassDiagram(Figure, FilterMixIn):
    """Main class diagram handling."""

    TYPE = "class"

    def __init__(self, title: str, mode: str) -> None:
        """Create an empty diagram.

        ``mode`` is forwarded unchanged to ``FilterMixIn.__init__``.
        """
        FilterMixIn.__init__(self, mode)
        Figure.__init__(self)
        self.title = title
        # TODO: Specify 'Any' after refactor of `DiagramEntity`
        self.objects: list[Any] = []
        # Relationships grouped by their type string.
        self.relationships: dict[str, list[Relationship]] = {}
        # Maps each registered astroid node to its diagram entity.
        self._nodes: dict[nodes.NodeNG, DiagramEntity] = {}

    def get_relationships(self, role: str) -> Iterable[Relationship]:
        """Return the relationships stored under ``role``.

        An unknown ``role`` yields an empty list.
        """
        # sorted to get predictable (hence testable) results
        return sorted(
            self.relationships.get(role, ()),
            key=lambda x: (x.from_object.fig_id, x.to_object.fig_id),
        )

    def add_relationship(
        self,
        from_object: DiagramEntity,
        to_object: DiagramEntity,
        relation_type: str,
        name: str | None = None,
    ) -> None:
        """Create a relationship and store it under ``relation_type``."""
        rel = Relationship(from_object, to_object, relation_type, name)
        self.relationships.setdefault(relation_type, []).append(rel)

    def get_relationship(
        self, from_object: DiagramEntity, relation_type: str
    ) -> Relationship:
        """Return the first ``relation_type`` relationship starting at ``from_object``.

        The source object is matched by identity (``is``), not equality.

        :raises KeyError: if no such relationship exists; ``None`` is never
            returned.
        """
        for rel in self.relationships.get(relation_type, ()):
            if rel.from_object is from_object:
                return rel
        raise KeyError(relation_type)

    def get_attrs(self, node: nodes.ClassDef) -> list[str]:
        """Return visible attributes, possibly with class name.

        Candidates are the entries of ``node.instance_attrs_type`` and
        ``node.locals_type`` plus the functions of ``node`` decorated as
        properties. Names rejected by ``self.show_attr`` are skipped. When
        ``class_names`` returns names for the associated nodes, they are
        appended as ``"<attr> : <name>, <name>"``. The result is sorted.
        """
        attrs: list[str] = []
        properties = [
            (n, m)
            for n, m in node.items()
            if isinstance(m, nodes.FunctionDef) and decorated_with_property(m)
        ]
        for node_name, associated_nodes in (
            list(node.instance_attrs_type.items())
            + list(node.locals_type.items())
            + properties
        ):
            if not self.show_attr(node_name):
                continue
            names = self.class_names(associated_nodes)
            if names:
                node_name = f"{node_name} : {', '.join(names)}"
            attrs.append(node_name)
        return sorted(attrs)

    def get_methods(self, node: nodes.ClassDef) -> list[nodes.FunctionDef]:
        """Return visible methods, sorted by name.

        Property objects and property-decorated functions are excluded here;
        the decorated ones are reported by ``get_attrs`` instead.
        """
        methods = [
            m
            for m in node.values()
            if isinstance(m, nodes.FunctionDef)
            and not isinstance(m, astroid.objects.Property)
            and not decorated_with_property(m)
            and self.show_attr(m.name)
        ]
        return sorted(methods, key=lambda n: n.name)  # type: ignore[no-any-return]

    def add_object(self, title: str, node: nodes.ClassDef) -> None:
        """Create a diagram object for ``node`` and register it.

        ``node`` must not already be part of the diagram (checked with an
        ``assert``).
        """
        assert node not in self._nodes
        ent = ClassEntity(title, node)
        self._nodes[node] = ent
        self.objects.append(ent)

    def class_names(self, nodes_lst: Iterable[nodes.NodeNG]) -> list[str]:
        """Return class names if needed in diagram.

        Instances are replaced by their proxied node first. A name is kept
        when the node is one of the accepted node types, has a ``name``
        attribute, is not itself an object of this diagram, and the name was
        not collected already. First-seen order is preserved.
        """
        names: list[str] = []
        for node in nodes_lst:
            if isinstance(node, astroid.Instance):
                node = node._proxied
            if (
                isinstance(
                    node, (nodes.ClassDef, nodes.Name, nodes.Subscript, nodes.BinOp)
                )
                and hasattr(node, "name")
                and not self.has_node(node)
                and node.name not in names
            ):
                names.append(node.name)
        return names

    def has_node(self, node: nodes.NodeNG) -> bool:
        """Return true if the given node is included in the diagram."""
        return node in self._nodes

    def object_from_node(self, node: nodes.NodeNG) -> DiagramEntity:
        """Return the diagram object mapped to node.

        :raises KeyError: if ``node`` is not part of the diagram.
        """
        return self._nodes[node]

    def classes(self) -> list[ClassEntity]:
        """Return all class nodes in the diagram."""
        return [o for o in self.objects if isinstance(o, ClassEntity)]

    def classe(self, name: str) -> ClassEntity:
        """Return a class by its name, raise KeyError if not found."""
        for klass in self.classes():
            if klass.node.name == name:
                return klass
        raise KeyError(name)

    def extract_relationships(self) -> None:
        """Extract relationships between nodes in the diagram.

        For every class entity this fills ``attrs`` and ``methods``, resets
        the shape to ``"class"`` and records:

        * a ``"specialization"`` link to each direct ancestor that is part of
          the diagram;
        * ``"aggregation"`` links from ``node.aggregations_type``;
        * ``"association"`` links from ``node.associations_type`` and
          ``node.locals_type``.
        """
        for obj in self.classes():
            node = obj.node
            obj.attrs = self.get_attrs(node)
            obj.methods = self.get_methods(node)
            obj.shape = "class"
            # inheritance link
            for par_node in node.ancestors(recurs=False):
                try:
                    par_obj = self.object_from_node(par_node)
                    self.add_relationship(obj, par_obj, "specialization")
                except KeyError:
                    # The ancestor is not drawn in this diagram.
                    continue

            # associations & aggregations links
            for name, values in list(node.aggregations_type.items()):
                for value in values:
                    self.assign_association_relationship(
                        value, obj, name, "aggregation"
                    )

            for name, values in list(node.associations_type.items()) + list(
                node.locals_type.items()
            ):
                for value in values:
                    self.assign_association_relationship(
                        value, obj, name, "association"
                    )

    def assign_association_relationship(
        self, value: astroid.NodeNG, obj: ClassEntity, name: str, type_relationship: str
    ) -> None:
        """Record a ``type_relationship`` link between ``value`` and ``obj``.

        Uninferable values are ignored and instances are replaced by their
        proxied node. If the resulting node is part of the diagram, the
        relationship goes *from* its entity *to* ``obj`` and is labelled
        ``name``; otherwise nothing is recorded.
        """
        if isinstance(value, util.UninferableBase):
            return
        if isinstance(value, astroid.Instance):
            value = value._proxied
        try:
            associated_obj = self.object_from_node(value)
            self.add_relationship(associated_obj, obj, type_relationship, name)
        except KeyError:
            # The associated node is not drawn in this diagram.
            return


class PackageDiagram(ClassDiagram):
    """Package diagram handling."""

    TYPE = "package"

    def modules(self) -> list[PackageEntity]:
        """Return all module nodes in the diagram."""
        return [o for o in self.objects if isinstance(o, PackageEntity)]

    def module(self, name: str) -> PackageEntity:
        """Return a module by its name, raise KeyError if not found."""
        for mod in self.modules():
            if mod.node.name == name:
                return mod
        raise KeyError(name)

    def add_object(self, title: str, node: nodes.Module) -> None:
        """Create a diagram object for ``node`` and register it.

        ``node`` must not already be part of the diagram (checked with an
        ``assert``).
        """
        assert node not in self._nodes
        ent = PackageEntity(title, node)
        self._nodes[node] = ent
        self.objects.append(ent)

    def get_module(self, name: str, node: nodes.Module) -> PackageEntity:
        """Return a module by its name, looking also for relative imports.

        With ``package`` being the name of ``node``'s root, a module matches
        when its name is ``name``, ``"<package>.<name>"``, or ``name``
        prefixed by ``package`` with its last dotted component removed. The
        first matching module in diagram order is returned.

        :raises KeyError: if no module matches.
        """
        # The candidate names do not depend on the module being inspected, so
        # build them once instead of on every iteration.
        package = node.root().name
        candidates = (
            name,
            # search for fullname of relative import modules
            f"{package}.{name}",
            f"{package.rsplit('.', 1)[0]}.{name}",
        )
        for mod in self.modules():
            if mod.node.name in candidates:
                return mod
        raise KeyError(name)

    def add_from_depend(self, node: nodes.ImportFrom, from_module: str) -> None:
        """Add dependencies created by from-imports.

        ``from_module`` is appended to the ``depends`` list of the module
        node containing ``node``, unless it is already listed.

        :raises KeyError: if that module is not part of the diagram.
        """
        mod_name = node.root().name
        obj = self.module(mod_name)
        if from_module not in obj.node.depends:
            obj.node.depends.append(from_module)

    def extract_relationships(self) -> None:
        """Extract relationships between nodes in the diagram.

        Runs the class-level extraction first, then records an
        ``"ownership"`` link from each class to its module and a
        ``"depends"`` link from each module to every dependency that can be
        resolved inside the diagram.
        """
        super().extract_relationships()
        for class_obj in self.classes():
            # ownership
            try:
                mod = self.object_from_node(class_obj.node.root())
                self.add_relationship(class_obj, mod, "ownership")
            except KeyError:
                # The owning module is not drawn in this diagram.
                continue
        for package_obj in self.modules():
            package_obj.shape = "package"
            # dependencies
            for dep_name in package_obj.node.depends:
                try:
                    dep = self.get_module(dep_name, package_obj.node)
                except KeyError:
                    # The dependency is not drawn in this diagram.
                    continue
                self.add_relationship(package_obj, dep, "depends")