summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/ansible/plugins/__init__.py48
1 files changed, 31 insertions, 17 deletions
diff --git a/lib/ansible/plugins/__init__.py b/lib/ansible/plugins/__init__.py
index 83956753dc..9ac54a8d06 100644
--- a/lib/ansible/plugins/__init__.py
+++ b/lib/ansible/plugins/__init__.py
@@ -316,6 +316,7 @@ class PluginLoader:
def get(self, name, *args, **kwargs):
''' instantiates a plugin of the given name using arguments '''
+ class_only = kwargs.pop('class_only', False)
if name in self.aliases:
name = self.aliases[name]
path = self.find_plugin(name)
@@ -325,23 +326,28 @@ class PluginLoader:
if path not in self._module_cache:
self._module_cache[path] = self._load_module_source('.'.join([self.package, name]), path)
- if kwargs.get('class_only', False):
- obj = getattr(self._module_cache[path], self.class_name)
- else:
- obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
- if self.base_class:
- # The import path is hardcoded and should be the right place,
- # so we are not expecting an ImportError.
- module = __import__(self.package, fromlist=[self.base_class])
- # Check whether this obj has the required base class.
- if not issubclass(obj.__class__, getattr(module, self.base_class, None)):
- return None
+ obj = getattr(self._module_cache[path], self.class_name)
+ if self.base_class:
+ # The import path is hardcoded and should be the right place,
+ # so we are not expecting an ImportError.
+ module = __import__(self.package, fromlist=[self.base_class])
+ # Check whether this obj has the required base class.
+ try:
+ plugin_class = getattr(module, self.base_class)
+ except AttributeError:
+ return None
+ if not issubclass(obj, plugin_class):
+ return None
+
+ if not class_only:
+ obj = obj(*args, **kwargs)
return obj
def all(self, *args, **kwargs):
''' instantiates all plugins with the same arguments '''
+ class_only = kwargs.pop('class_only', False)
for i in self._get_paths():
matches = glob.glob(os.path.join(i, "*.py"))
matches.sort()
@@ -353,14 +359,22 @@ class PluginLoader:
if path not in self._module_cache:
self._module_cache[path] = self._load_module_source(name, path)
- if kwargs.get('class_only', False):
- obj = getattr(self._module_cache[path], self.class_name)
- else:
- obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
-
- if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]:
+ obj = getattr(self._module_cache[path], self.class_name)
+ if self.base_class:
+ # The import path is hardcoded and should be the right place,
+ # so we are not expecting an ImportError.
+ module = __import__(self.package, fromlist=[self.base_class])
+ # Check whether this obj has the required base class.
+ try:
+ plugin_class = getattr(module, self.base_class)
+ except AttributeError:
+ continue
+ if not issubclass(obj, plugin_class):
continue
+ if not class_only:
+ obj = obj(*args, **kwargs)
+
# set extra info on the module, in case we want it later
setattr(obj, '_original_path', path)
yield obj