diff options
author | Richard Eisenberg <rae@richarde.dev> | 2019-06-04 14:31:08 -0400 |
---|---|---|
committer | Ben Gamari <ben@smart-cactus.org> | 2019-06-25 14:37:38 -0400 |
commit | 7ffe0681d44d002af357dd97f81590c804abb324 (patch) | |
tree | ddc47e08f844f71bbdd5b192d9e2389106e6d1a3 | |
parent | f9da172cde69071121dc8698bc1488b454ab32b2 (diff) | |
download | haskell-7ffe0681d44d002af357dd97f81590c804abb324.tar.gz |
GHCi support for levity-polymorphic join points
Fixes #16509.
See Note [Levity-polymorphic join points] in ByteCodeGen,
which tells the full story.
This commit also adds some comments and cleans up some code
in the byte-code generator, as I was exploring around trying
to understand it.
test case: ghci/scripts/T16509
(cherry picked from commit 392210bf8a27b3604f8642d76c39e391c2d4b5e0)
-rw-r--r-- | compiler/ghci/ByteCodeAsm.hs | 6 | ||||
-rw-r--r-- | compiler/ghci/ByteCodeGen.hs | 116 | ||||
-rw-r--r-- | compiler/ghci/ByteCodeInstr.hs | 10 | ||||
-rw-r--r-- | compiler/simplStg/RepType.hs | 2 | ||||
-rw-r--r-- | testsuite/tests/ghci/scripts/T16509.hs | 11 | ||||
-rw-r--r-- | testsuite/tests/ghci/scripts/T16509.script | 1 | ||||
-rwxr-xr-x | testsuite/tests/ghci/scripts/all.T | 1 |
7 files changed, 114 insertions, 33 deletions
diff --git a/compiler/ghci/ByteCodeAsm.hs b/compiler/ghci/ByteCodeAsm.hs index 0776e406d6..e3c18b93a2 100644 --- a/compiler/ghci/ByteCodeAsm.hs +++ b/compiler/ghci/ByteCodeAsm.hs @@ -156,7 +156,11 @@ assembleOneBCO hsc_env pbco = do return ubco' assembleBCO :: DynFlags -> ProtoBCO Name -> IO UnlinkedBCO -assembleBCO dflags (ProtoBCO nm instrs bitmap bsize arity _origin _malloced) = do +assembleBCO dflags (ProtoBCO { protoBCOName = nm + , protoBCOInstrs = instrs + , protoBCOBitmap = bitmap + , protoBCOBitmapSize = bsize + , protoBCOArity = arity }) = do -- pass 1: collect up the offsets of the local labels. let asm = mapM_ (assembleI dflags) instrs diff --git a/compiler/ghci/ByteCodeGen.hs b/compiler/ghci/ByteCodeGen.hs index 113690780b..0f5d6496dc 100644 --- a/compiler/ghci/ByteCodeGen.hs +++ b/compiler/ghci/ByteCodeGen.hs @@ -26,6 +26,7 @@ import Platform import Name import MkId import Id +import Var ( updateVarType ) import ForeignCall import HscTypes import CoreUtils @@ -61,7 +62,6 @@ import Data.Char import UniqSupply import Module -import Control.Arrow ( second ) import Control.Exception import Data.Array @@ -90,7 +90,7 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks (const ()) $ do -- Split top-level binds into strings and others. -- See Note [generating code for top-level string literal bindings]. - let (strings, flatBinds) = partitionEithers $ do + let (strings, flatBinds) = partitionEithers $ do -- list monad (bndr, rhs) <- flattenBinds binds return $ case exprIsTickedString_maybe rhs of Just str -> Left (bndr, str) @@ -181,29 +181,13 @@ coreExprToBCOs hsc_env this_mod expr where dflags = hsc_dflags hsc_env -- The regular freeVars function gives more information than is useful to --- us here. simpleFreeVars does the impedance matching. +-- us here. We need only the free variables, not everything in an FVAnn. +-- Historical note: At one point FVAnn was more sophisticated than just +-- a set. Now it isn't. So this function is much simpler. 
Keeping it around +-- so that if someone changes FVAnn, they will get a nice type error right +-- here. simpleFreeVars :: CoreExpr -> AnnExpr Id DVarSet -simpleFreeVars = go . freeVars - where - go :: AnnExpr Id FVAnn -> AnnExpr Id DVarSet - go (ann, e) = (freeVarsOfAnn ann, go' e) - - go' :: AnnExpr' Id FVAnn -> AnnExpr' Id DVarSet - go' (AnnVar id) = AnnVar id - go' (AnnLit lit) = AnnLit lit - go' (AnnLam bndr body) = AnnLam bndr (go body) - go' (AnnApp fun arg) = AnnApp (go fun) (go arg) - go' (AnnCase scrut bndr ty alts) = AnnCase (go scrut) bndr ty (map go_alt alts) - go' (AnnLet bind body) = AnnLet (go_bind bind) (go body) - go' (AnnCast expr (ann, co)) = AnnCast (go expr) (freeVarsOfAnn ann, co) - go' (AnnTick tick body) = AnnTick tick (go body) - go' (AnnType ty) = AnnType ty - go' (AnnCoercion co) = AnnCoercion co - - go_alt (con, args, expr) = (con, args, go expr) - - go_bind (AnnNonRec bndr rhs) = AnnNonRec bndr (go rhs) - go_bind (AnnRec pairs) = AnnRec (map (second go) pairs) +simpleFreeVars = freeVars -- ----------------------------------------------------------------------------- -- Compilation schema for the bytecode generator @@ -256,6 +240,7 @@ mkProtoBCO -> name -> BCInstrList -> Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet) + -- ^ original expression; for debugging only -> Int -> Word16 -> [StgWord] @@ -368,6 +353,9 @@ schemeR fvs (nm, rhs) -} = schemeR_wrk fvs nm rhs (collect rhs) +-- If an expression is a lambda (after apply bcView), return the +-- list of arguments to the lambda (in R-to-L order) and the +-- underlying expression collect :: AnnExpr Id DVarSet -> ([Var], AnnExpr' Id DVarSet) collect (_, e) = go [] e where @@ -382,8 +370,8 @@ collect (_, e) = go [] e schemeR_wrk :: [Id] -> Id - -> AnnExpr Id DVarSet - -> ([Var], AnnExpr' Var DVarSet) + -> AnnExpr Id DVarSet -- expression e, for debugging only + -> ([Var], AnnExpr' Var DVarSet) -- result of collect on e -> BcM (ProtoBCO Name) schemeR_wrk fvs nm original_body (args, body) = do 
@@ -508,8 +496,16 @@ schemeE d s p e@(AnnLit lit) = returnUnboxedAtom d s p e (typeArgRep (litera schemeE d s p e@(AnnCoercion {}) = returnUnboxedAtom d s p e V schemeE d s p e@(AnnVar v) + -- See Note [Levity-polymorphic join points], step 3. + | isLPJoinPoint v = schemeT d s p $ + AnnApp (bogus_fvs, AnnVar (protectLPJoinPointId v)) + (bogus_fvs, AnnVar voidPrimId) + -- schemeT will call splitApp, dropping the fvs. + | isUnliftedType (idType v) = returnUnboxedAtom d s p e (bcIdArgRep v) | otherwise = schemeT d s p e + where + bogus_fvs = pprPanic "schemeE bogus_fvs" (ppr v) schemeE d s p (AnnLet (AnnNonRec x (_,rhs)) (_,body)) | (AnnVar v, args_r_to_l) <- splitApp rhs, @@ -534,19 +530,22 @@ schemeE d s p (AnnLet binds (_,body)) = do fvss = map (fvsToEnv p' . fst) rhss + -- See Note [Levity-polymorphic join points], step 2. + (xs',rhss') = zipWithAndUnzip protectLPJoinPointBind xs rhss + -- Sizes of free vars size_w = trunc16W . idSizeW dflags sizes = map (\rhs_fvs -> sum (map size_w rhs_fvs)) fvss -- the arity of each rhs - arities = map (genericLength . fst . collect) rhss + arities = map (genericLength . fst . collect) rhss' -- This p', d' defn is safe because all the items being pushed -- are ptrs, so all have size 1 word. d' and p' reflect the stack -- after the closures have been allocated in the heap (but not -- filled in), and pointers to them parked on the stack. offsets = mkStackOffsets d (genericReplicate n_binds (wordSize dflags)) - p' = Map.insertList (zipE xs offsets) p + p' = Map.insertList (zipE xs' offsets) p d' = d + wordsToBytes dflags n_binds zipE = zipEqual "schemeE" @@ -587,7 +586,7 @@ schemeE d s p (AnnLet binds (_,body)) = do compile_binds = [ compile_bind d' fvs x rhs size arity (trunc16W n) | (fvs, x, rhs, size, arity, n) <- - zip6 fvss xs rhss sizes arities [n_binds, n_binds-1 .. 1] + zip6 fvss xs' rhss' sizes arities [n_binds, n_binds-1 .. 
1] ] body_code <- schemeE d' s p' body thunk_codes <- sequence compile_binds @@ -681,6 +680,30 @@ schemeE _ _ _ expr = pprPanic "ByteCodeGen.schemeE: unhandled case" (pprCoreExpr (deAnnotate' expr)) +-- Is this Id a levity-polymorphic join point? +-- See Note [Levity-polymorphic join points], step 1 +isLPJoinPoint :: Id -> Bool +isLPJoinPoint x = isJoinId x && + isNothing (isLiftedType_maybe (idType x)) + +-- If necessary, modify this Id and body to protect levity-polymorphic join points. +-- See Note [Levity-polymorphic join points], step 2. +protectLPJoinPointBind :: Id -> AnnExpr Id DVarSet -> (Id, AnnExpr Id DVarSet) +protectLPJoinPointBind x rhs@(fvs, _) + | isLPJoinPoint x + = (protectLPJoinPointId x, (fvs, AnnLam voidArgId rhs)) + + | otherwise + = (x, rhs) + +-- Update an Id's type to take a Void# argument. +-- Precondition: the Id is a levity-polymorphic join point. +-- See Note [Levity-polymorphic join points] +protectLPJoinPointId :: Id -> Id +protectLPJoinPointId x + = ASSERT( isLPJoinPoint x ) + updateVarType (voidPrimTy `mkFunTy`) x + {- Ticked Expressions ------------------ @@ -689,6 +712,41 @@ schemeE _ _ _ expr the code. When we find such a thing, we pull out the useful information, and then compile the code as if it was just the expression E. +Note [Levity-polymorphic join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +A join point variable is essentially a goto-label: it is, for example, +never used as an argument to another function, and it is called only +in tail position. See Note [Join points] and Note [Invariants on join points], +both in CoreSyn. Because join points do not compile to true, red-blooded +variables (with, e.g., registers allocated to them), they are allowed +to be levity-polymorphic. (See invariant #6 in Note [Invariants on join points] +in CoreSyn.) + +However, in this byte-code generator, join points *are* treated just as +ordinary variables. 
There is no check whether a binding is for a join point +or not; they are all treated uniformly. (Perhaps there is a missed optimization +opportunity here, but that is beyond the scope of my (Richard E's) Thursday.) + +We thus must have *some* strategy for dealing with levity-polymorphic join +points (LPJPs), because we cannot have a levity-polymorphic variable. +(Not having such a strategy led to #16509, which panicked in the isUnliftedType +check in the AnnVar case of schemeE.) Here is the strategy: + +1. Detect LPJPs. This is done in isLPJoinPoint. + +2. When binding an LPJP, add a `\ (_ :: Void#) ->` to its RHS, and modify the + type to tack on a `Void# ->`. (Void# is written voidPrimTy within GHC.) + Note that functions are never levity-polymorphic, so this transformation + changes an LPJP to a non-levity-polymorphic join point. This is done + in protectLPJoinPointBind, called from the AnnLet case of schemeE. + +3. At an occurrence of an LPJP, add an application to void# (called voidPrimId), + being careful to note the new type of the LPJP. This is done in the AnnVar + case of schemeE, with help from protectLPJoinPointId. + +It's a bit hacky, but it works well in practice and is local. I suspect the +Right Fix is to take advantage of join points as goto-labels. + -} -- Compile code to do a tail call. 
Specifically, push the fn, diff --git a/compiler/ghci/ByteCodeInstr.hs b/compiler/ghci/ByteCodeInstr.hs index 07dcd2222a..d405e1ade7 100644 --- a/compiler/ghci/ByteCodeInstr.hs +++ b/compiler/ghci/ByteCodeInstr.hs @@ -45,7 +45,7 @@ data ProtoBCO a protoBCOBitmap :: [StgWord], protoBCOBitmapSize :: Word16, protoBCOArity :: Int, - -- what the BCO came from + -- what the BCO came from, for debugging only protoBCOExpr :: Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet), -- malloc'd pointers protoBCOFFIs :: [FFIInfo] @@ -179,7 +179,13 @@ data BCInstr -- Printing bytecode instructions instance Outputable a => Outputable (ProtoBCO a) where - ppr (ProtoBCO name instrs bitmap bsize arity origin ffis) + ppr (ProtoBCO { protoBCOName = name + , protoBCOInstrs = instrs + , protoBCOBitmap = bitmap + , protoBCOBitmapSize = bsize + , protoBCOArity = arity + , protoBCOExpr = origin + , protoBCOFFIs = ffis }) = (text "ProtoBCO" <+> ppr name <> char '#' <> int arity <+> text (show ffis) <> colon) $$ nest 3 (case origin of diff --git a/compiler/simplStg/RepType.hs b/compiler/simplStg/RepType.hs index 4d437d3b7c..522eeb1ab3 100644 --- a/compiler/simplStg/RepType.hs +++ b/compiler/simplStg/RepType.hs @@ -64,7 +64,7 @@ isNvUnaryType ty = False -- INVARIANT: the result list is never empty. 
-typePrimRepArgs :: Type -> [PrimRep] +typePrimRepArgs :: HasDebugCallStack => Type -> [PrimRep] typePrimRepArgs ty | [] <- reps = [VoidRep] diff --git a/testsuite/tests/ghci/scripts/T16509.hs b/testsuite/tests/ghci/scripts/T16509.hs new file mode 100644 index 0000000000..6f35e3c792 --- /dev/null +++ b/testsuite/tests/ghci/scripts/T16509.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} + +module PatternPanic where + +pattern TestPat :: (Int, Int) +pattern TestPat <- (isSameRef -> True, 0) + +isSameRef :: Int -> Bool +isSameRef e | 0 <- e = True +isSameRef _ = False diff --git a/testsuite/tests/ghci/scripts/T16509.script b/testsuite/tests/ghci/scripts/T16509.script new file mode 100644 index 0000000000..3e40de0b91 --- /dev/null +++ b/testsuite/tests/ghci/scripts/T16509.script @@ -0,0 +1 @@ +:l T16509 diff --git a/testsuite/tests/ghci/scripts/all.T b/testsuite/tests/ghci/scripts/all.T index 5162a3c220..b6772d4c37 100755 --- a/testsuite/tests/ghci/scripts/all.T +++ b/testsuite/tests/ghci/scripts/all.T @@ -295,5 +295,6 @@ test('T15941', normal, ghci_script, ['T15941.script']) test('T16030', normal, ghci_script, ['T16030.script']) test('T11606', normal, ghci_script, ['T11606.script']) test('T16089', normal, ghci_script, ['T16089.script']) +test('T16509', normal, ghci_script, ['T16509.script']) test('T16527', normal, ghci_script, ['T16527.script']) test('T16767', normal, ghci_script, ['T16767.script']) |