diff options
author | R. Tyler Ballance <tyler@slide.com> | 2009-07-12 17:00:26 -0700 |
---|---|---|
committer | R. Tyler Ballance <tyler@slide.com> | 2009-07-12 17:00:26 -0700 |
commit | 6dec6a56a749e1e45f40d602eb3c47f447f4b1dd (patch) | |
tree | 02986a7dc96ab38b9184f27449cf4e6ac9f09148 | |
parent | f21c4c2aa7f724df6e2f101c887ee4e518db6a90 (diff) | |
download | python-cheetah-6dec6a56a749e1e45f40d602eb3c47f447f4b1dd.tar.gz |
Add preliminary support for multiple inheritance via the #extends directive
This is covered in mantis #26
-rw-r--r-- | src/Compiler.py | 64 | ||||
-rw-r--r-- | src/Parser.py | 39 | ||||
-rw-r--r-- | src/Tests/Template.py | 4 |
3 files changed, 74 insertions, 33 deletions
diff --git a/src/Compiler.py b/src/Compiler.py index 4b8e44e..39c7f51 100644 --- a/src/Compiler.py +++ b/src/Compiler.py @@ -1739,38 +1739,40 @@ class ModuleCompiler(SettingsManager, GenUtils): # - We also assume that the final . separates the classname from the # module name. This might break if people do something really fancy # with their dots and namespaces. - chunks = baseClassName.split('.') - if len(chunks)==1: - self._getActiveClassCompiler().setBaseClass(baseClassName) - if baseClassName not in self.importedVarNames(): - modName = baseClassName - # we assume the class name to be the module name - # and that it's not a builtin: - importStatement = "from %s import %s" % (modName, baseClassName) - self.addImportStatement(importStatement) - self.addImportedVarNames( [baseClassName,] ) - else: - needToAddImport = True - modName = chunks[0] - #print chunks, ':', self.importedVarNames() - for chunk in chunks[1:-1]: - if modName in self.importedVarNames(): - needToAddImport = False - finalBaseClassName = baseClassName.replace(modName+'.', '') - self._getActiveClassCompiler().setBaseClass(finalBaseClassName) - break - else: - modName += '.'+chunk - if needToAddImport: - modName, finalClassName = '.'.join(chunks[:-1]), chunks[-1] - #if finalClassName != chunks[:-1][-1]: - if finalClassName != chunks[-2]: + baseclasses = baseClassName.split(',') + for klass in baseclasses: + chunks = klass.split('.') + if len(chunks)==1: + self._getActiveClassCompiler().setBaseClass(klass) + if klass not in self.importedVarNames(): + modName = klass # we assume the class name to be the module name - modName = '.'.join(chunks) - self._getActiveClassCompiler().setBaseClass(finalClassName) - importStatement = "from %s import %s" % (modName, finalClassName) - self.addImportStatement(importStatement) - self.addImportedVarNames( [finalClassName,] ) + # and that it's not a builtin: + importStatement = "from %s import %s" % (modName, klass) + self.addImportStatement(importStatement) + self.addImportedVarNames((klass,)) + else: + needToAddImport = True + modName = chunks[0] + #print chunks, ':', self.importedVarNames() + for chunk in chunks[1:-1]: + if modName in self.importedVarNames(): + needToAddImport = False + finalBaseClassName = klass.replace(modName+'.', '') + self._getActiveClassCompiler().setBaseClass(finalBaseClassName) + break + else: + modName += '.'+chunk + if needToAddImport: + modName, finalClassName = '.'.join(chunks[:-1]), chunks[-1] + #if finalClassName != chunks[:-1][-1]: + if finalClassName != chunks[-2]: + # we assume the class name to be the module name + modName = '.'.join(chunks) + self._getActiveClassCompiler().setBaseClass(finalClassName) + importStatement = "from %s import %s" % (modName, finalClassName) + self.addImportStatement(importStatement) + self.addImportedVarNames( [finalClassName,] ) def setCompilerSetting(self, key, valueExpr): self.setSetting(key, eval(valueExpr) ) diff --git a/src/Parser.py b/src/Parser.py index 3e6e7fe..7436e9c 100644 --- a/src/Parser.py +++ b/src/Parser.py @@ -596,6 +596,42 @@ class _LowLevelParser(SourceReader): if not match: raise ParseError(self, msg='Invalid multi-line comment end token') return self.readTo(match.end()) + + def getCommaSeparatedSymbols(self): + """ + Loosely based on getDottedName to pull out comma separated + named chunks + """ + srcLen = len(self) + pieces = [] + nameChunks = [] + + if not self.peek() in identchars: + raise ParseError(self) + + while self.pos() < srcLen: + c = self.peek() + if c in namechars: + nameChunk = self.getIdentifier() + nameChunks.append(nameChunk) + elif c == '.': + if self.pos()+1 <srcLen and self.peek(1) in identchars: + nameChunks.append(self.getc()) + else: + break + elif c == ',': + self.getc() + pieces.append(''.join(nameChunks)) + nameChunks = [] + elif c in (' ', '\t'): + self.getc() + else: + break + + if nameChunks: + pieces.append(''.join(nameChunks)) + + return pieces def getDottedName(self): srcLen = len(self) @@ -2037,7 +2073,8 @@ class _HighLevelParser(_LowLevelParser): if self.setting('allowExpressionsInExtendsDirective'): baseName = self.getExpression() else: - baseName = self.getDottedName() + baseName = self.getCommaSeparatedSymbols() + baseName = ', '.join(baseName) baseName = self._applyExpressionFilters(baseName, 'extends', startPos=startPos) self._compiler.setBaseClass(baseName) # in compiler diff --git a/src/Tests/Template.py b/src/Tests/Template.py index 06f9768..085180d 100644 --- a/src/Tests/Template.py +++ b/src/Tests/Template.py @@ -338,7 +338,9 @@ class MultipleInheritanceSupport(TemplateTest): #return [4,5] + $boink() #end def ''' - template = Template.compile(template) + template = Template.compile(template, + moduleGlobals={'Useless' : Useless}, + compilerSettings={'autoImportForExtendsDirective' : False}) template = template() result = template.foo() print result |