summaryrefslogtreecommitdiff
path: root/ghc/compiler/simplStg/StgSATMonad.lhs
blob: f0cb84d4d10836614b29e774451454af05506365 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
%
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
%
%************************************************************************
%*									*
\section[SATMonad]{The Static Argument Transformation pass Monad}
%*									*
%************************************************************************

\begin{code}
#include "HsVersions.h"

module StgSATMonad (
	getArgLists, saTransform, 

	Id, UniType, SplitUniqSupply, PlainStgExpr(..)
    ) where

import AbsUniType	( mkTyVarTy, mkSigmaTy, TyVarTemplate,
			  extractTyVarsFromTy, splitType, splitTyArgs,
			  glueTyArgs, instantiateTy, TauType(..),
			  Class, ThetaType(..), SigmaType(..),
			  InstTyEnv(..)
			)
import IdEnv
import Id		( mkSysLocal, getIdUniType, eqId )
import Maybes		( Maybe(..) )
import StgSyn
import SATMonad         ( SATEnv(..), SATInfo(..), Arg(..), updSAEnv, insSAEnv,
                          SatM(..), initSAT, thenSAT, thenSAT_,
                          emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
                          getSATInfo, newSATName )
import SrcLoc		( SrcLoc, mkUnknownSrcLoc )
import SplitUniq
import Unique
import UniqSet		( UniqSet(..), emptyUniqSet )
import Util

\end{code}

%************************************************************************
%*									*
\subsection{Utility Functions}
%*									*
%************************************************************************

\begin{code}
-- Generate a fresh SAT name for every identifier in the list, in order.
-- Each new name is made by newSATName from the identifier and the
-- identifier's own type (getIdUniType).
--
-- The fresh names must be returned: the caller (saTransform) uses them as
-- the parameters of the new local function and as the targets of the
-- renaming in subst_env.  Returning the original identifiers would make
-- that renaming a no-op.
newSATNames :: [Id] -> SatM [Id]
newSATNames [] = returnSAT []
newSATNames (v:vs) = newSATName v (getIdUniType v)	`thenSAT` \ v' ->
                     newSATNames vs			`thenSAT` \ vs' ->
                     returnSAT (v':vs')

-- Initial argument classification of a right-hand side.
-- A constructor right-hand side has no arguments at all.
-- For a closure, every lambda-bound argument is wrapped in Static.
-- The first component of the result is always empty here.
getArgLists :: PlainStgRhs -> ([Arg UniType],[Arg Id])
getArgLists rhs
  = case rhs of
      StgRhsCon _ _ _                     -> ([], [])
      StgRhsClosure _ _ _ _ lambda_args _ -> ([], map Static lambda_args)

\end{code}

\begin{code}
-- Static argument transformation of a single recursive binding
-- `binder = rhs`.
--
-- If the SAT information recorded for binder (getSATInfo) has at least one
-- Static argument, the right-hand side is rebuilt by mkNewRhs:
--   * the outer closure keeps its original arguments;
--   * its body is a local recursive function binder' that takes only the
--     NotStatic arguments.
-- The result is returned as a non-recursive binding (StgNonRec).
-- Otherwise the binding is returned unchanged, wrapped as a one-element
-- StgRec.
saTransform :: Id -> PlainStgRhs -> SatM PlainStgBinding
saTransform binder rhs
  = getSATInfo binder `thenSAT` \ r ->
    case r of
      Just (_,args) | any isStatic args 
      -- [Andre] test: do it only if we have more than one static argument.
      --Just (_,args) | length (filter isStatic args) > 1
	-> newSATName binder (new_ty args)	`thenSAT` \ binder' ->
           let non_static_args = get_nsa args (snd (getArgLists rhs))
           in
	   newSATNames non_static_args		`thenSAT` \ non_static_args' ->
	   mkNewRhs binder binder' args rhs non_static_args' non_static_args
						`thenSAT` \ new_rhs ->
	   -- NOTE(review): this trace emits a diagnostic message (the number
	   -- of static arguments) for every binding transformed; it looks like
	   -- leftover debugging output -- confirm it is still wanted.
	   trace ("SAT(STG) "++ show (length (filter isStatic args))) (
           returnSAT (StgNonRec binder new_rhs)
           )
      _ -> returnSAT (StgRec [(binder, rhs)])

  where
    -- get_nsa sat_args rhs_args: the variables of rhs_args that sit in
    -- NotStatic positions of sat_args.  Only rhs_args entries of the form
    -- (Static v) are picked up (getArgLists wraps every closure argument
    -- that way); the walk stops at the end of the shorter list.
    get_nsa :: [Arg a] -> [Arg a] -> [a]
    get_nsa []			_		= []
    get_nsa _			[]		= []
    get_nsa (NotStatic:args)	(Static v:as)	= v:get_nsa args as
    get_nsa (_:args)		(_:as)		=   get_nsa args as

    -- Build the transformed right-hand side:
    --   \ rhsargs -> let rec binder' = \ non_static_args' -> body'
    --                in  binder' non_static_args
    -- body' is body rewritten by doStgSubst:
    --   * binder is renamed to binder';
    --   * each non_static_args entry is renamed to its non_static_args'
    --     counterpart;
    --   * recursive calls lose their Static arguments.
    -- Both closures reuse the original cost centre, binder info,
    -- free-variable list and update flag.
    -- NOTE(review): there is no equation for StgRhsCon; presumably
    -- getSATInfo never reports Static arguments for a constructor RHS --
    -- confirm.
    mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
      = let
	  local_body = StgApp (StgVarAtom binder')
			 [StgVarAtom a | a <- non_static_args] emptyUniqSet

	  rec_body = StgRhsClosure cc bi fvs upd non_static_args'
	               (doStgSubst binder args subst_env body)

	  subst_env = mkIdEnv 
                        ((binder,binder'):zip non_static_args non_static_args')
	in
	returnSAT (
	    StgRhsClosure cc bi fvs upd rhsargs 
	      (StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
	)

    -- Type for the local function binder': binder's type with the
    -- dictionary and regular argument types in Static positions removed.
    -- The first (length dict_tys) entries of args classify the dictionary
    -- arguments; the remaining entries classify the regular arguments.
    -- NOTE(review): tv_tmpl is not used -- the result is built with an
    -- empty type-variable list; confirm this is right for polymorphic
    -- binders.
    new_ty args
      = instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
      where
	-- get type info for the local function:
	(tv_tmpl, dict_tys, tau_ty) = (splitType . getIdUniType) binder
	(reg_arg_tys, res_type)	    = splitTyArgs tau_ty

	-- now, we drop the ones that are
	-- static, that is, the ones we will not pass to the local function
	l   	     = length dict_tys
	dict_tys'    = dropStatics (take l args) dict_tys
	reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
	tau_ty'	     = glueTyArgs reg_arg_tys' res_type
\end{code}

NOTE: doStgSubst below does not preserve live-variable or free-variable information: most of these annotations in the rewritten expression are simply reset to empty.

\begin{code}
-- Rewrite an STG expression for use as the body of the SAT'd local
-- function.
--
--   binder    -- the original recursive function being transformed
--   orig_args -- Static/NotStatic classification of binder's arguments,
--                one entry per argument position
--   subst_env -- renaming for variable atoms; it must map binder itself,
--                since binder is looked up with lookupNoFailIdEnv
--   body      -- the expression to rewrite
--
-- Calls of binder are redirected to binder's image in subst_env, and the
-- arguments in Static positions are dropped.  Every other variable atom is
-- renamed through subst_env when it has an entry there.  Names bound
-- inside the expression (let binders, case-alternative binders, closure
-- arguments) are not themselves renamed.  Live-variable and free-variable
-- annotations are not maintained: most are reset to empty.
--
-- Both substExpr and remove_static_args are total.  A call of binder with
-- more arguments than orig_args describes keeps (and renames) the extra
-- arguments instead of failing with a pattern-match error.  An application
-- headed by a literal atom is passed through with its arguments renamed.
doStgSubst binder orig_args subst_env body
  = substExpr body
  where
    substExpr (StgConApp con args _)
      = StgConApp con (map substAtom args) emptyUniqSet
    substExpr (StgPrimApp op args _)
      = StgPrimApp op (map substAtom args) emptyUniqSet
    -- A call of the function being transformed: go through its renamed
    -- version and pass only the arguments in NotStatic positions.
    substExpr (StgApp (StgVarAtom v) args _)
      | v `eqId` binder
      = StgApp (StgVarAtom (lookupNoFailIdEnv subst_env v))
               (remove_static_args orig_args args) emptyUniqSet
    -- Any other application, including one headed by a literal atom
    -- (which substAtom leaves untouched), keeps its live-variable set.
    substExpr (StgApp atom args lvs)
      = StgApp (substAtom atom) (map substAtom args) lvs
    substExpr (StgCase scrut _ _ uniq alts)
      = StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
      where
        subst_alts (StgAlgAlts ty alg_alts deflt)
          = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
        subst_alts (StgPrimAlts ty prim_alts deflt)
          = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
        subst_alg_alt (con, args, use_mask, rhs)
          = (con, args, use_mask, substExpr rhs)
        subst_prim_alt (lit, rhs)
          = (lit, substExpr rhs)
        subst_deflt StgNoDefault
          = StgNoDefault
        subst_deflt (StgBindDefault var used rhs)
          = StgBindDefault var used (substExpr rhs)
    substExpr (StgLetNoEscape _ _ b e)
      = StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr e)
    substExpr (StgLet b e)
      = StgLet (substBinding b) (substExpr e)
    substExpr (StgSCC ty cc e)
      = StgSCC ty cc (substExpr e)

    -- Closures lose their free-variable list.  Constructor right-hand
    -- sides only have their argument atoms renamed.
    substRhs (StgRhsCon cc con args)
      = StgRhsCon cc con (map substAtom args)
    substRhs (StgRhsClosure cc bi _ upd args e)
      = StgRhsClosure cc bi [] upd args (substExpr e)

    -- The bound names are kept; only the right-hand sides are rewritten.
    substBinding (StgNonRec b rhs)
      = StgNonRec b (substRhs rhs)
    substBinding (StgRec pairs)
      = StgRec [ (b, substRhs rhs) | (b, rhs) <- pairs ]

    substAtom atom@(StgLitAtom _) = atom
    substAtom atom@(StgVarAtom v)
      = case lookupIdEnv subst_env v of
          Just v' -> StgVarAtom v'
          Nothing -> atom

    -- Drop the call arguments that sit in Static positions of orig_args,
    -- and rename the rest.  Arguments beyond the end of orig_args (an
    -- over-saturated call) are kept and renamed.
    remove_static_args _ []
      = []
    remove_static_args [] as
      = map substAtom as
    remove_static_args (Static _:origs) (_:as)
      = remove_static_args origs as
    remove_static_args (NotStatic:origs) (a:as)
      = substAtom a:remove_static_args origs as
\end{code}