summaryrefslogtreecommitdiff
path: root/cheetah/ImportHooks.py
blob: 0ae51419ffe045a8e5388dc390bef2c624ccd3c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python

"""
Provides some import hooks to allow Cheetah's .tmpl files to be imported
directly like Python .py modules.

To use these:
  import Cheetah.ImportHooks
  Cheetah.ImportHooks.install()
"""

import sys
import os.path
import types
import __builtin__
import imp
from threading import RLock
import string
import traceback
import types

from Cheetah import ImportManager
from Cheetah.ImportManager import DirOwner
from Cheetah.Compiler import Compiler
from Cheetah.convertTmplPathToModuleName import convertTmplPathToModuleName

# Module-level flag consulted by install() and uninstall() before they touch
# the import machinery.
_installed = False

##################################################
## HELPER FUNCS

# Holder for the cache directory. An empty list means no cache directory has
# been configured; otherwise element 0 is the directory in effect. A list is
# used so the value can be replaced in place without rebinding the name.
_cacheDir = []
def setCacheDir(cacheDir):
    """Set the directory in which compiled template source is cached.

    The new value replaces any previously configured directory. Earlier
    code appended to the list, but only element 0 is ever consulted, so a
    second call used to be silently ignored.

    cacheDir -- path of the directory to write generated ``.py`` files to.
    """
    # Slice assignment mutates the existing list object, so anything holding
    # a reference to it observes the new directory.
    _cacheDir[:] = [cacheDir]
    
##################################################
## CLASSES

class CheetahDirOwner(DirOwner):
    """Directory owner that can also load Cheetah template files as modules.

    Lookup is first delegated to ``DirOwner.getmod``. Only when that finds
    nothing is the directory searched for ``name`` plus one of
    ``templateFileExtensions``; the first match is compiled into a module.
    """

    # A single re-entrant lock shared by all instances (class attribute);
    # getmod() holds it for the whole lookup and compilation.
    _lock = RLock()
    _acquireLock = _lock.acquire
    _releaseLock = _lock.release

    # Extensions treated as Cheetah templates. install() overwrites this
    # class attribute with the tuple it is given.
    templateFileExtensions = ('.tmpl',)

    def getmod(self, name):
        """Return the module called ``name`` from this directory, or None.

        A truthy result from ``DirOwner.getmod`` is returned unchanged.
        Otherwise the first existing ``<self.path>/<name><ext>`` template
        is compiled and returned.

        Raises ImportError, carrying the original traceback text, if the
        template exists but compiling it fails.
        """
        self._acquireLock()
        try:
            mod = DirOwner.getmod(self, name)
            if mod:
                return mod

            for ext in self.templateFileExtensions:
                tmplPath = os.path.join(self.path, name + ext)
                if os.path.exists(tmplPath):
                    try:
                        return self._compile(name, tmplPath)
                    except Exception:
                        # Only ordinary errors are converted; a bare
                        # ``except`` would also turn KeyboardInterrupt and
                        # SystemExit into an ImportError.
                        # @@TR: log the error
                        exc_txt = traceback.format_exc()
                        # Indent every line of the traceback by two spaces.
                        exc_txt = '  ' + '\n  '.join(exc_txt.splitlines())
                        raise ImportError(
                            'Error while compiling Cheetah module'
                            ' %s, original traceback follows:\n%s'
                            % (name, exc_txt))
            ##
            return None

        finally:
            self._releaseLock()

    def _compile(self, name, tmplPath):
        """Compile the template at ``tmplPath`` and wrap it in a module.

        The generated Python source is compiled to a code object which is
        stored on the new module as ``__co__``; the module body is not
        executed here. When a cache directory is configured the source is
        also written there as a ``.py`` file and that path becomes the
        module's ``__file__``. If the write fails the error is printed to
        stderr and the template path is used instead.

        name     -- module name, also used as the template's main class name.
        tmplPath -- path of the template file to compile.

        Returns the new ``types.ModuleType`` instance.
        """
        ## @@ consider adding an ImportError raiser here
        code = str(Compiler(file=tmplPath, moduleName=name,
                            mainClassName=name))
        fileName = tmplPath
        if _cacheDir:
            cachePath = os.path.join(_cacheDir[0],
                                     convertTmplPathToModuleName(tmplPath)) + '.py'
            try:
                cacheFile = open(cachePath, 'w')
                try:
                    cacheFile.write(code)
                finally:
                    # Always release the handle, even if the write fails.
                    cacheFile.close()
                fileName = cachePath
            except (IOError, OSError):
                # open()/write() raise IOError on Python 2, which a plain
                # ``except OSError`` does not catch; handle both so the
                # fallback to the template path actually happens.
                ## @@ TR: need to add some error code here
                traceback.print_exc(file=sys.stderr)
        co = compile(code + '\n', fileName, 'exec')

        mod = types.ModuleType(name)
        mod.__file__ = co.co_filename
        if _cacheDir:
            mod.__orig_file__ = tmplPath # @@TR: this is used in the WebKit
                                         # filemonitoring code
        mod.__co__ = co
        return mod
        

##################################################
## FUNCTIONS

def install(templateFileExtensions=('.tmpl',)):
    """Install the Cheetah Import Hooks.

    templateFileExtensions -- tuple of file extensions to treat as Cheetah
        templates. It is applied to CheetahDirOwner on every call, including
        calls made after the hooks are already installed.

    The hooks are only installed when ``__builtin__.__import__`` is still
    the built-in function, i.e. nothing else has replaced it. Calling
    install() again after a successful install does not register a second
    manager.
    """
    global _installed, _manager, __oldimport__

    # Kept unconditional so a repeated install() call can still change the
    # recognised extensions, as it always could.
    CheetahDirOwner.templateFileExtensions = templateFileExtensions
    if _installed:
        return

    import __builtin__
    if isinstance(__builtin__.__import__, types.BuiltinFunctionType):
        # Remember the original importer so uninstall() can restore it.
        __oldimport__ = __builtin__.__import__
        ImportManager._globalOwnerTypes.insert(0, CheetahDirOwner)
        #ImportManager._globalOwnerTypes.append(CheetahDirOwner)
        _manager = ImportManager.ImportManager()
        _manager.setThreaded()
        _manager.install()
        # Record success; previously the flag was never set, so it could not
        # guard anything.
        _installed = True

def uninstall():
    """Uninstall the Cheetah Import Hooks.

    Does nothing unless install() previously installed the hooks. Restores
    the importer saved by install(), unregisters CheetahDirOwner and drops
    the manager.
    """
    global _installed, _manager
    if not _installed:
        return

    import __builtin__
    # NOTE(review): presumably ImportManager.install() replaces __import__
    # with a bound method, which is what this check detects -- confirm
    # against ImportManager.
    if isinstance(__builtin__.__import__, types.MethodType):
        __builtin__.__import__ = __oldimport__
        try:
            # Undo the registration made by install() so that a later
            # install() does not leave the owner type listed twice.
            ImportManager._globalOwnerTypes.remove(CheetahDirOwner)
        except ValueError:
            # Already removed by other code; nothing left to undo.
            pass
        del _manager
        _installed = False

# Running this module as a script calls install() with its default arguments.
if __name__ == '__main__':
    install()