summaryrefslogtreecommitdiff
path: root/compiler/main/StaticPtrTable.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/main/StaticPtrTable.hs')
-rw-r--r--compiler/main/StaticPtrTable.hs175
1 files changed, 121 insertions, 54 deletions
diff --git a/compiler/main/StaticPtrTable.hs b/compiler/main/StaticPtrTable.hs
index 9ec970f453..694c874711 100644
--- a/compiler/main/StaticPtrTable.hs
+++ b/compiler/main/StaticPtrTable.hs
@@ -46,79 +46,146 @@
--
{-# LANGUAGE ViewPatterns #-}
-module StaticPtrTable (sptModuleInitCode) where
+module StaticPtrTable (sptCreateStaticBinds) where
--- See SimplCore Note [Grand plan for static forms]
+-- See SimplCore Note [Grand plan for static forms] for an overview.
import CLabel
import CoreSyn
+import CoreUtils (collectMakeStaticArgs)
import DataCon
+import DynFlags
+import HscTypes
import Id
-import Literal
import Module
+import Name
import Outputable
+import Platform
import PrelNames
+import Type
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State
+import Data.List
import Data.Maybe
import GHC.Fingerprint
--- | @sptModuleInitCode module binds@ is a C stub to insert the static entries
--- found in @binds@ of @module@ into the static pointer table.
+-- | Replaces all bindings of the form
--
--- A bind is considered a static entry if it is an application of the
--- data constructor @StaticPtr@.
+-- > b = /\ ... -> makeStatic info value
--
-sptModuleInitCode :: Module -> CoreProgram -> SDoc
-sptModuleInitCode this_mod binds =
- sptInitCode $ catMaybes
- $ map (\(b, e) -> ((,) b) <$> staticPtrFp e)
- $ flattenBinds binds
+-- with
+--
+-- > b = /\ ... -> StaticPtr key info value
+--
+-- where a distinct key is generated for each binding.
+--
+-- It also yields the C stub that inserts these bindings into the static
+-- pointer table.
+sptCreateStaticBinds :: HscEnv -> Module -> CoreProgram
+ -> IO (SDoc, CoreProgram)
+sptCreateStaticBinds hsc_env this_mod binds = do
+ (fps, binds') <- evalStateT (go [] [] binds) 0
+ return (sptModuleInitCode this_mod fps, binds')
where
- staticPtrFp :: CoreExpr -> Maybe Fingerprint
- staticPtrFp (collectTyBinders -> (_, e))
- | (Var v, _ : Lit lit0 : Lit lit1 : _) <- collectArgs e
- , Just con <- isDataConId_maybe v
- , dataConName con == staticPtrDataConName
- , Just w0 <- fromPlatformWord64Rep lit0
- , Just w1 <- fromPlatformWord64Rep lit1
- = Just $ Fingerprint (fromInteger w0) (fromInteger w1)
- staticPtrFp _ = Nothing
+ go fps bs xs = case xs of
+ [] -> return (reverse fps, reverse bs)
+ bnd : xs' -> do
+ (fps', bnd') <- replaceStaticBind bnd
+ go (reverse fps' ++ fps) (bnd' : bs) xs'
+
+ -- Generates keys and replaces 'makeStatic' with 'StaticPtr'.
+ --
+ -- The 'Int' state is used to produce a different key for each binding.
+ replaceStaticBind :: CoreBind
+ -> StateT Int IO ([(Id, Fingerprint)], CoreBind)
+ replaceStaticBind (NonRec b e) = do (mfp, (b', e')) <- replaceStatic b e
+ return (maybeToList mfp, NonRec b' e')
+ replaceStaticBind (Rec rbs) = do
+ (mfps, rbs') <- unzip <$> mapM (uncurry replaceStatic) rbs
+ return (catMaybes mfps, Rec rbs')
+
+ replaceStatic :: Id -> CoreExpr
+ -> StateT Int IO (Maybe (Id, Fingerprint), (Id, CoreExpr))
+ replaceStatic b e@(collectTyBinders -> (tvs, e0)) =
+ case collectMakeStaticArgs e0 of
+ Nothing -> return (Nothing, (b, e))
+ Just (_, t, info, arg) -> do
+ (fp, e') <- mkStaticBind t info arg
+ return (Just (b, fp), (b, foldr Lam e' tvs))
+
+ mkStaticBind :: Type -> CoreExpr -> CoreExpr
+ -> StateT Int IO (Fingerprint, CoreExpr)
+ mkStaticBind t info e = do
+ i <- get
+ put (i + 1)
+ let fp@(Fingerprint w0 w1) = mkStaticPtrFingerprint i
+ dflags = hsc_dflags hsc_env
- fromPlatformWord64Rep (MachWord w) = Just w
- fromPlatformWord64Rep (MachWord64 w) = Just w
- fromPlatformWord64Rep _ = Nothing
+ staticPtrDataCon <- lift $ lookupDataCon staticPtrDataConName
+ return (fp, mkConApp staticPtrDataCon
+ [ Type t
+ , mkWord64LitWordRep dflags w0
+ , mkWord64LitWordRep dflags w1
+ , info
+ , e ])
- sptInitCode :: [(Id, Fingerprint)] -> SDoc
- sptInitCode [] = Outputable.empty
- sptInitCode entries = vcat
- [ text "static void hs_spt_init_" <> ppr this_mod
- <> text "(void) __attribute__((constructor));"
- , text "static void hs_spt_init_" <> ppr this_mod <> text "(void)"
- , braces $ vcat $
- [ text "static StgWord64 k" <> int i <> text "[2] = "
- <> pprFingerprint fp <> semi
- $$ text "extern StgPtr "
- <> (ppr $ mkClosureLabel (idName n) (idCafInfo n)) <> semi
- $$ text "hs_spt_insert" <> parens
- (hcat $ punctuate comma
- [ char 'k' <> int i
- , char '&' <> ppr (mkClosureLabel (idName n) (idCafInfo n))
- ]
- )
- <> semi
- | (i, (n, fp)) <- zip [0..] entries
- ]
- , text "static void hs_spt_fini_" <> ppr this_mod
- <> text "(void) __attribute__((destructor));"
- , text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)"
- , braces $ vcat $
- [ text "StgWord64 k" <> int i <> text "[2] = "
- <> pprFingerprint fp <> semi
- $$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi
- | (i, (_, fp)) <- zip [0..] entries
- ]
- ]
+ mkStaticPtrFingerprint :: Int -> Fingerprint
+ mkStaticPtrFingerprint n = fingerprintString $ intercalate ":"
+ [ unitIdString $ moduleUnitId this_mod
+ , moduleNameString $ moduleName this_mod
+ , show n
+ ]
+ -- Choose either 'Word64#' or 'Word#' to represent the arguments of the
+ -- 'Fingerprint' data constructor.
+ mkWord64LitWordRep dflags
+ | platformWordSize (targetPlatform dflags) < 8 = mkWord64LitWord64
+ | otherwise = mkWordLit dflags . toInteger
+
+ lookupDataCon :: Name -> IO DataCon
+ lookupDataCon n = lookupTypeHscEnv hsc_env n >>=
+ maybe (getError n) (return . tyThingDataCon)
+
+ getError n = pprPanic "sptCreateStaticBinds.get: not found" $
+ text "Couldn't find" <+> ppr n
+
+-- | @sptModuleInitCode module fps@ is a C stub to insert the static entries
+-- of @module@ into the static pointer table.
+--
+-- @fps@ is a list associating each binding corresponding to a static entry with
+-- its fingerprint.
+sptModuleInitCode :: Module -> [(Id, Fingerprint)] -> SDoc
+sptModuleInitCode _ [] = Outputable.empty
+sptModuleInitCode this_mod entries = vcat
+ [ text "static void hs_spt_init_" <> ppr this_mod
+ <> text "(void) __attribute__((constructor));"
+ , text "static void hs_spt_init_" <> ppr this_mod <> text "(void)"
+ , braces $ vcat $
+ [ text "static StgWord64 k" <> int i <> text "[2] = "
+ <> pprFingerprint fp <> semi
+ $$ text "extern StgPtr "
+ <> (ppr $ mkClosureLabel (idName n) (idCafInfo n)) <> semi
+ $$ text "hs_spt_insert" <> parens
+ (hcat $ punctuate comma
+ [ char 'k' <> int i
+ , char '&' <> ppr (mkClosureLabel (idName n) (idCafInfo n))
+ ]
+ )
+ <> semi
+ | (i, (n, fp)) <- zip [0..] entries
+ ]
+ , text "static void hs_spt_fini_" <> ppr this_mod
+ <> text "(void) __attribute__((destructor));"
+ , text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)"
+ , braces $ vcat $
+ [ text "StgWord64 k" <> int i <> text "[2] = "
+ <> pprFingerprint fp <> semi
+ $$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi
+ | (i, (_, fp)) <- zip [0..] entries
+ ]
+ ]
+ where
pprFingerprint :: Fingerprint -> SDoc
pprFingerprint (Fingerprint w1 w2) =
braces $ hcat $ punctuate comma