summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Shao <terrorjack@type.dance>2022-12-01 12:56:23 +0000
committerCheng Shao <terrorjack@type.dance>2022-12-16 21:16:28 +0000
commit8a81d9d933089b6ed72478342a0070d7c8f82ff8 (patch)
tree441142f702b4655515757472f72cd46b5a9e08c6
parent1c6930bf59223b6a70ca6045e2bbd4a4fb297b76 (diff)
downloadhaskell-8a81d9d933089b6ed72478342a0070d7c8f82ff8.tar.gz
compiler: add optional tail-call support in wasm NCG
When the `-mtail-call` clang flag is passed at configure time, wasm tail-call extension is enabled, and the wasm NCG will emit `return_call`/`return_call_indirect` instructions to take advantage of it and avoid the `StgRun` trampoline overhead. Closes #22461.
-rw-r--r--compiler/GHC/CmmToAsm.hs7
-rw-r--r--compiler/GHC/CmmToAsm/Wasm.hs15
-rw-r--r--compiler/GHC/CmmToAsm/Wasm/Asm.hs57
-rw-r--r--compiler/GHC/CmmToAsm/Wasm/Types.hs27
-rw-r--r--compiler/GHC/Driver/CodeOutput.hs3
-rw-r--r--compiler/GHC/Wasm/ControlFlow.hs8
-rw-r--r--compiler/GHC/Wasm/ControlFlow/FromCmm.hs2
7 files changed, 87 insertions, 32 deletions
diff --git a/compiler/GHC/CmmToAsm.hs b/compiler/GHC/CmmToAsm.hs
index f4eb61c449..c0ac96fa79 100644
--- a/compiler/GHC/CmmToAsm.hs
+++ b/compiler/GHC/CmmToAsm.hs
@@ -136,6 +136,7 @@ import GHC.Types.Unique.Set
import GHC.Unit
import GHC.Data.Stream (Stream)
import qualified GHC.Data.Stream as Stream
+import GHC.Settings
import Data.List (sortBy)
import Data.List.NonEmpty (groupAllWith, head)
@@ -146,10 +147,10 @@ import System.IO
import System.Directory ( getCurrentDirectory )
--------------------
-nativeCodeGen :: forall a . Logger -> NCGConfig -> ModLocation -> Handle -> UniqSupply
+nativeCodeGen :: forall a . Logger -> ToolSettings -> NCGConfig -> ModLocation -> Handle -> UniqSupply
-> Stream IO RawCmmGroup a
-> IO a
-nativeCodeGen logger config modLoc h us cmms
+nativeCodeGen logger ts config modLoc h us cmms
= let platform = ncgPlatform config
nCG' :: ( OutputableP Platform statics, Outputable jumpDest, Instruction instr)
=> NcgImpl statics instr jumpDest -> IO a
@@ -169,7 +170,7 @@ nativeCodeGen logger config modLoc h us cmms
ArchLoongArch64->panic "nativeCodeGen: No NCG for LoongArch64"
ArchUnknown -> panic "nativeCodeGen: No NCG for unknown arch"
ArchJavaScript-> panic "nativeCodeGen: No NCG for JavaScript"
- ArchWasm32 -> Wasm32.ncgWasm platform us modLoc h cmms
+ ArchWasm32 -> Wasm32.ncgWasm platform ts us modLoc h cmms
-- | Data accumulated during code generation. Mostly about statistics,
-- but also collects debug data for DWARF generation.
diff --git a/compiler/GHC/CmmToAsm/Wasm.hs b/compiler/GHC/CmmToAsm/Wasm.hs
index 6ea3244db4..ed2d4eb2dd 100644
--- a/compiler/GHC/CmmToAsm/Wasm.hs
+++ b/compiler/GHC/CmmToAsm/Wasm.hs
@@ -14,22 +14,28 @@ import GHC.CmmToAsm.Wasm.Types
import GHC.Data.Stream (Stream, StreamS (..), runStream)
import GHC.Platform
import GHC.Prelude
+import GHC.Settings
import GHC.Types.Unique.Supply
import GHC.Unit
+import GHC.Utils.CliOption
import System.IO
ncgWasm ::
Platform ->
+ ToolSettings ->
UniqSupply ->
ModLocation ->
Handle ->
Stream IO RawCmmGroup a ->
IO a
-ncgWasm platform us loc h cmms = do
+ncgWasm platform ts us loc h cmms = do
(r, s) <- streamCmmGroups platform us cmms
hPutBuilder h $ "# " <> string7 (fromJust $ ml_hs_file loc) <> "\n\n"
- hPutBuilder h $ execWasmAsmM $ asmTellEverything TagI32 s
+ hPutBuilder h $ execWasmAsmM do_tail_call $ asmTellEverything TagI32 s
pure r
+ where
+ -- See Note [WasmTailCall]
+ do_tail_call = doTailCall ts
streamCmmGroups ::
Platform ->
@@ -43,3 +49,8 @@ streamCmmGroups platform us cmms =
go s (Done r) = pure (r, s)
go s (Effect m) = m >>= go s
go s (Yield cmm k) = go (wasmExecM (onCmmGroup cmm) s) k
+
+doTailCall :: ToolSettings -> Bool
+doTailCall ts = Option "-mtail-call" `elem` as_args
+ where
+ (_, as_args) = toolSettings_pgm_a ts
diff --git a/compiler/GHC/CmmToAsm/Wasm/Asm.hs b/compiler/GHC/CmmToAsm/Wasm/Asm.hs
index feb56371ce..5ccce28676 100644
--- a/compiler/GHC/CmmToAsm/Wasm/Asm.hs
+++ b/compiler/GHC/CmmToAsm/Wasm/Asm.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Strict #-}
@@ -32,13 +33,13 @@ import GHC.Utils.Outputable hiding ((<>))
import GHC.Utils.Panic (panic)
-- | Reads current indentation, appends result to state
-newtype WasmAsmM a = WasmAsmM (Builder -> State Builder a)
+newtype WasmAsmM a = WasmAsmM (Bool -> Builder -> State Builder a)
deriving
( Functor,
Applicative,
Monad
)
- via (ReaderT Builder (State Builder))
+ via (ReaderT Bool (ReaderT Builder (State Builder)))
instance Semigroup a => Semigroup (WasmAsmM a) where
(<>) = liftA2 (<>)
@@ -46,27 +47,33 @@ instance Semigroup a => Semigroup (WasmAsmM a) where
instance Monoid a => Monoid (WasmAsmM a) where
mempty = pure mempty
+-- | To tail call or not, that is the question
+doTailCall :: WasmAsmM Bool
+doTailCall = WasmAsmM $ \do_tail_call _ -> pure do_tail_call
+
-- | Default indent level is none
-execWasmAsmM :: WasmAsmM a -> Builder
-execWasmAsmM (WasmAsmM m) = execState (m mempty) mempty
+execWasmAsmM :: Bool -> WasmAsmM a -> Builder
+execWasmAsmM do_tail_call (WasmAsmM m) =
+ execState (m do_tail_call mempty) mempty
-- | Increase indent level by a tab
asmWithTab :: WasmAsmM a -> WasmAsmM a
-asmWithTab (WasmAsmM m) = WasmAsmM $ \t -> m $! char7 '\t' <> t
+asmWithTab (WasmAsmM m) =
+ WasmAsmM $ \do_tail_call t -> m do_tail_call $! char7 '\t' <> t
-- | Writes a single line starting with the current indent
asmTellLine :: Builder -> WasmAsmM ()
-asmTellLine b = WasmAsmM $ \t -> modify $ \acc -> acc <> t <> b <> char7 '\n'
+asmTellLine b = WasmAsmM $ \_ t -> modify $ \acc -> acc <> t <> b <> char7 '\n'
-- | Writes a single line break
asmTellLF :: WasmAsmM ()
-asmTellLF = WasmAsmM $ \_ -> modify $ \acc -> acc <> char7 '\n'
+asmTellLF = WasmAsmM $ \_ _ -> modify $ \acc -> acc <> char7 '\n'
-- | Writes a line starting with a single tab, ignoring current indent
-- level
asmTellTabLine :: Builder -> WasmAsmM ()
asmTellTabLine b =
- WasmAsmM $ \_ -> modify $ \acc -> acc <> char7 '\t' <> b <> char7 '\n'
+ WasmAsmM $ \_ _ -> modify $ \acc -> acc <> char7 '\t' <> b <> char7 '\n'
asmFromWasmType :: WasmTypeTag t -> Builder
asmFromWasmType ty = case ty of
@@ -386,7 +393,25 @@ asmTellWasmControl ty_word c = case c of
WasmBrTable (WasmExpr e) _ ts t -> do
asmTellWasmInstr ty_word e
asmTellLine $ "br_table {" <> builderCommas intDec (ts <> [t]) <> "}"
- WasmReturnTop _ -> asmTellLine "return"
+ -- See Note [WasmTailCall]
+ WasmTailCall (WasmExpr e) -> do
+ do_tail_call <- doTailCall
+ if
+ | do_tail_call,
+ WasmSymConst sym <- e ->
+ asmTellLine $ "return_call " <> asmFromSymName sym
+ | do_tail_call ->
+ do
+ asmTellWasmInstr ty_word e
+ asmTellLine $
+ "return_call_indirect "
+ <> asmFromFuncType
+ []
+ [SomeWasmType ty_word]
+ | otherwise ->
+ do
+ asmTellWasmInstr ty_word e
+ asmTellLine "return"
WasmActions (WasmStatements a) -> asmTellWasmInstr ty_word a
WasmSeq c0 c1 -> do
asmTellWasmControl ty_word c0
@@ -465,18 +490,20 @@ asmTellProducers = do
asmTellTargetFeatures :: WasmAsmM ()
asmTellTargetFeatures = do
+ do_tail_call <- doTailCall
asmTellSectionHeader ".custom_section.target_features"
asmTellVec
[ do
asmTellTabLine ".int8 0x2b"
asmTellBS feature
| feature <-
- [ "bulk-memory",
- "mutable-globals",
- "nontrapping-fptoint",
- "reference-types",
- "sign-ext"
- ]
+ ["tail-call" | do_tail_call]
+ <> [ "bulk-memory",
+ "mutable-globals",
+ "nontrapping-fptoint",
+ "reference-types",
+ "sign-ext"
+ ]
]
asmTellEverything :: WasmTypeTag w -> WasmCodeGenState w -> WasmAsmM ()
diff --git a/compiler/GHC/CmmToAsm/Wasm/Types.hs b/compiler/GHC/CmmToAsm/Wasm/Types.hs
index 06d2c246e6..fa3287f0ec 100644
--- a/compiler/GHC/CmmToAsm/Wasm/Types.hs
+++ b/compiler/GHC/CmmToAsm/Wasm/Types.hs
@@ -352,9 +352,30 @@ data WasmControl :: Type -> Type -> [WasmType] -> [WasmType] -> Type where
WasmControl s e dropped destination
-- invariant: the table interval is contained
-- within [0 .. pred (length targets)]
- WasmReturnTop ::
- WasmTypeTag t ->
- WasmControl s e (t : t1star) t2star -- as per type system
+
+ -- Note [WasmTailCall]
+ -- ~~~~~~~~~~~~~~~~~~~
+ -- This represents the exit point of each CmmGraph: tail calling the
+ -- destination in CmmCall. The STG stack may grow before the call,
+ -- but it's always a tail call in the sense that the C call stack is
+ -- guaranteed not to grow.
+ --
+ -- In the wasm backend, WasmTailCall is lowered to different
+ -- assembly code given whether the wasm tail-call extension is
+ -- enabled:
+ --
+ -- When tail-call is not enabled (which is the default as of today),
+ -- a WasmTailCall is lowered to code that pushes the callee function
+ -- pointer onto the value stack and returns immediately. The actual
+ -- call is done by the trampoline in StgRun.
+ --
+ -- When tail-call is indeed enabled via passing -mtail-call in
+ -- CONF_CC_OPTS_STAGE2 at configure time, a WasmTailCall is lowered
+ -- to return_call/return_call_indirect, thus tail calling into its
+ -- callee without returning to StgRun.
+ WasmTailCall ::
+ e ->
+ WasmControl s e t1star t2star -- as per type system
WasmActions ::
s ->
WasmControl s e stack stack -- basic block: one entry, one exit
diff --git a/compiler/GHC/Driver/CodeOutput.hs b/compiler/GHC/Driver/CodeOutput.hs
index 934d958120..c5c0534d20 100644
--- a/compiler/GHC/Driver/CodeOutput.hs
+++ b/compiler/GHC/Driver/CodeOutput.hs
@@ -201,7 +201,7 @@ outputAsm logger dflags this_mod location filenm cmm_stream = do
let ncg_config = initNCGConfig dflags this_mod
{-# SCC "OutputAsm" #-} doOutput filenm $
\h -> {-# SCC "NativeCodeGen" #-}
- nativeCodeGen logger ncg_config location h ncg_uniqs cmm_stream
+ nativeCodeGen logger (toolSettings dflags) ncg_config location h ncg_uniqs cmm_stream
{-
************************************************************************
@@ -397,4 +397,3 @@ ipInitCode do_info_table platform this_mod
ipe_buffer_decl =
text "extern IpeBufferListNode" <+> ipe_buffer_label <> text ";"
-
diff --git a/compiler/GHC/Wasm/ControlFlow.hs b/compiler/GHC/Wasm/ControlFlow.hs
index 97c703597e..365a003323 100644
--- a/compiler/GHC/Wasm/ControlFlow.hs
+++ b/compiler/GHC/Wasm/ControlFlow.hs
@@ -1,10 +1,10 @@
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE DataKinds, GADTs, RankNTypes, TypeOperators, KindSignatures #-}
+{-# LANGUAGE DataKinds, GADTs, RankNTypes, KindSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternSynonyms #-}
module GHC.Wasm.ControlFlow
- ( WasmControl(..), (<>), pattern WasmIf, wasmReturn
+ ( WasmControl(..), (<>), pattern WasmIf
, BrTableInterval(..), inclusiveInterval
, WasmType, WasmTypeTag(..)
@@ -47,7 +47,3 @@ pattern WasmIf :: WasmFunctionType pre post
pattern WasmIf ty e t f =
WasmPush TagI32 e `WasmSeq` WasmIfTop ty t f
-
--- More syntactic sugar.
-wasmReturn :: WasmTypeTag t -> e -> WasmControl s e (t ': t1star) t2star
-wasmReturn tag e = WasmPush tag e `WasmSeq` WasmReturnTop tag
diff --git a/compiler/GHC/Wasm/ControlFlow/FromCmm.hs b/compiler/GHC/Wasm/ControlFlow/FromCmm.hs
index 8235b59ed6..0667345162 100644
--- a/compiler/GHC/Wasm/ControlFlow/FromCmm.hs
+++ b/compiler/GHC/Wasm/ControlFlow/FromCmm.hs
@@ -198,7 +198,7 @@ structuredControl platform txExpr txBlock g =
<$> txExpr xlabel e
<*> doBranch fty xlabel t (IfThenElse maybeMarks `inside` context)
<*> doBranch fty xlabel f (IfThenElse maybeMarks `inside` context)
- TailCall e -> (WasmPush TagI32 <$> txExpr xlabel e) <<>> pure (WasmReturnTop TagI32)
+ TailCall e -> WasmTailCall <$> txExpr xlabel e
Switch e range targets default' ->
WasmBrTable <$> txExpr xlabel e
<$~> range