summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNobuaki Sukegawa <nsuke@apache.org>2016-01-24 04:01:27 +0900
committerJens Geyer <jensg@apache.org>2017-12-03 17:45:33 +0100
commitfc0ff81ee7d4aa95a041c826dd5a83239ef98780 (patch)
tree51cca76f2a8f1d23de9f09d585f9fa6923ae51ee
parent1310dc1eb4014457a667a7287d1fa113432c7a54 (diff)
downloadthrift-fc0ff81ee7d4aa95a041c826dd5a83239ef98780.tar.gz
THRIFT-3580 THeader for Haskell
Client: hs This closes #820 This closes #1423
-rw-r--r--compiler/cpp/src/thrift/generate/t_hs_generator.cc64
-rw-r--r--lib/hs/src/Thrift.hs4
-rw-r--r--lib/hs/src/Thrift/Protocol.hs41
-rw-r--r--lib/hs/src/Thrift/Protocol/Binary.hs59
-rw-r--r--lib/hs/src/Thrift/Protocol/Compact.hs62
-rw-r--r--lib/hs/src/Thrift/Protocol/Header.hs141
-rw-r--r--lib/hs/src/Thrift/Protocol/JSON.hs58
-rw-r--r--lib/hs/src/Thrift/Server.hs6
-rw-r--r--lib/hs/src/Thrift/Transport/Handle.hs14
-rw-r--r--lib/hs/src/Thrift/Transport/Header.hs354
-rw-r--r--lib/hs/thrift.cabal2
-rw-r--r--test/hs/TestClient.hs6
-rw-r--r--test/hs/TestServer.hs15
-rw-r--r--test/known_failures_Linux.json4
-rw-r--r--test/tests.json1
15 files changed, 690 insertions, 141 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_hs_generator.cc b/compiler/cpp/src/thrift/generate/t_hs_generator.cc
index 30eb8fa9a..d0a8cb2d6 100644
--- a/compiler/cpp/src/thrift/generate/t_hs_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_hs_generator.cc
@@ -711,13 +711,13 @@ void t_hs_generator::generate_hs_struct_reader(ofstream& out, t_struct* tstruct)
string tmap = type_name(tstruct, "typemap_");
indent(out) << "to_" << sname << " _ = P.error \"not a struct\"" << endl;
- indent(out) << "read_" << sname << " :: (T.Transport t, T.Protocol p) => p t -> P.IO " << sname
+ indent(out) << "read_" << sname << " :: T.Protocol p => p -> P.IO " << sname
<< endl;
indent(out) << "read_" << sname << " iprot = to_" << sname;
out << " <$> T.readVal iprot (T.T_STRUCT " << tmap << ")" << endl;
indent(out) << "decode_" << sname
- << " :: (T.Protocol p, T.Transport t) => p t -> LBS.ByteString -> " << sname << endl;
+ << " :: T.StatelessProtocol p => p -> LBS.ByteString -> " << sname << endl;
indent(out) << "decode_" << sname << " iprot bs = to_" << sname << " $ ";
out << "T.deserializeVal iprot (T.T_STRUCT " << tmap << ") bs" << endl;
}
@@ -818,13 +818,13 @@ void t_hs_generator::generate_hs_struct_writer(ofstream& out, t_struct* tstruct)
indent_down();
// write
- indent(out) << "write_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name
+ indent(out) << "write_" << name << " :: T.Protocol p => p -> " << name
<< " -> P.IO ()" << endl;
indent(out) << "write_" << name << " oprot record = T.writeVal oprot $ from_";
out << name << " record" << endl;
// encode
- indent(out) << "encode_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name
+ indent(out) << "encode_" << name << " :: T.StatelessProtocol p => p -> " << name
<< " -> LBS.ByteString" << endl;
indent(out) << "encode_" << name << " oprot record = T.serializeVal oprot $ ";
out << "from_" << name << " record" << endl;
@@ -1085,8 +1085,9 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
// Serialize the request header
string fname = (*f_iter)->get_name();
string msgType = (*f_iter)->is_oneway() ? "T.M_ONEWAY" : "T.M_CALL";
- indent(f_client_) << "T.writeMessageBegin op (\"" << fname << "\", " << msgType << ", seqn)"
+ indent(f_client_) << "T.writeMessage op (\"" << fname << "\", " << msgType << ", seqn) $"
<< endl;
+ indent_up();
indent(f_client_) << "write_" << argsname << " op (" << argsname << "{";
bool first = true;
@@ -1102,10 +1103,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
first = false;
}
f_client_ << "})" << endl;
- indent(f_client_) << "T.writeMessageEnd op" << endl;
-
- // Write to the stream
- indent(f_client_) << "T.tFlush (T.getTransport op)" << endl;
+ indent_down();
indent_down();
if (!(*f_iter)->is_oneway()) {
@@ -1119,12 +1117,12 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
indent(f_client_) << funname << " ip = do" << endl;
indent_up();
- indent(f_client_) << "(fname, mtype, rseqid) <- T.readMessageBegin ip" << endl;
+ indent(f_client_) << "T.readMessage ip $ \\(fname, mtype, rseqid) -> do" << endl;
+ indent_up();
indent(f_client_) << "M.when (mtype == T.M_EXCEPTION) $ do { exn <- T.readAppExn ip ; "
- "T.readMessageEnd ip ; X.throw exn }" << endl;
+ "X.throw exn }" << endl;
indent(f_client_) << "res <- read_" << resultname << " ip" << endl;
- indent(f_client_) << "T.readMessageEnd ip" << endl;
t_struct* xs = (*f_iter)->get_xceptions();
const vector<t_field*>& xceptions = xs->get_members();
@@ -1142,6 +1140,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
// Close function
indent_down();
+ indent_down();
}
}
@@ -1180,11 +1179,11 @@ void t_hs_generator::generate_service_server(t_service* tservice) {
f_service_ << "do" << endl;
indent_up();
indent(f_service_) << "_ <- T.readVal iprot (T.T_STRUCT Map.empty)" << endl;
- indent(f_service_) << "T.writeMessageBegin oprot (name,T.M_EXCEPTION,seqid)" << endl;
+ indent(f_service_) << "T.writeMessage oprot (name,T.M_EXCEPTION,seqid) $" << endl;
+ indent_up();
indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN_METHOD (\"Unknown function "
"\" ++ LT.unpack name))" << endl;
- indent(f_service_) << "T.writeMessageEnd oprot" << endl;
- indent(f_service_) << "T.tFlush (T.getTransport oprot)" << endl;
+ indent_down();
indent_down();
}
@@ -1194,9 +1193,8 @@ void t_hs_generator::generate_service_server(t_service* tservice) {
indent(f_service_) << "process handler (iprot, oprot) = do" << endl;
indent_up();
- indent(f_service_) << "(name, typ, seqid) <- T.readMessageBegin iprot" << endl;
- indent(f_service_) << "proc_ handler (iprot,oprot) (name,typ,seqid)" << endl;
- indent(f_service_) << "T.readMessageEnd iprot" << endl;
+ indent(f_service_) << "T.readMessage iprot (" << endl;
+ indent(f_service_) << " proc_ handler (iprot,oprot))" << endl;
indent(f_service_) << "P.return P.True" << endl;
indent_down();
}
@@ -1286,11 +1284,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
if (tfunction->is_oneway()) {
indent(f_service_) << "P.return ()";
} else {
- indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
- << "\", T.M_REPLY, seqid)" << endl;
- indent(f_service_) << "write_" << resultname << " oprot res" << endl;
- indent(f_service_) << "T.writeMessageEnd oprot" << endl;
- indent(f_service_) << "T.tFlush (T.getTransport oprot)";
+ indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
+ << "\", T.M_REPLY, seqid) $" << endl;
+ indent_up();
+ indent(f_service_) << "write_" << resultname << " oprot res";
+ indent_down();
}
if (n > 0) {
f_service_ << ")";
@@ -1307,11 +1305,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
indent(f_service_) << "let res = default_" << resultname << "{"
<< field_name(resultname, (*x_iter)->get_name()) << " = P.Just e}"
<< endl;
- indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
- << "\", T.M_REPLY, seqid)" << endl;
- indent(f_service_) << "write_" << resultname << " oprot res" << endl;
- indent(f_service_) << "T.writeMessageEnd oprot" << endl;
- indent(f_service_) << "T.tFlush (T.getTransport oprot)";
+ indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
+ << "\", T.M_REPLY, seqid) $" << endl;
+ indent_up();
+ indent(f_service_) << "write_" << resultname << " oprot res";
+ indent_down();
} else {
indent(f_service_) << "P.return ()";
}
@@ -1324,11 +1322,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
indent_up();
if (!tfunction->is_oneway()) {
- indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
- << "\", T.M_EXCEPTION, seqid)" << endl;
- indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")" << endl;
- indent(f_service_) << "T.writeMessageEnd oprot" << endl;
- indent(f_service_) << "T.tFlush (T.getTransport oprot)";
+ indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
+ << "\", T.M_EXCEPTION, seqid) $" << endl;
+ indent_up();
+ indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")";
+ indent_down();
} else {
indent(f_service_) << "P.return ()";
}
diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs
index 58a304b6e..658020991 100644
--- a/lib/hs/src/Thrift.hs
+++ b/lib/hs/src/Thrift.hs
@@ -90,13 +90,13 @@ data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String }
deriving ( Show, Typeable )
instance Exception AppExn
-writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO ()
+writeAppExn :: Protocol p => p -> AppExn -> IO ()
writeAppExn pt ae = writeVal pt $ TStruct $ Map.fromList
[ (1, ("message", TString $ encodeUtf8 $ pack $ ae_message ae))
, (2, ("type", TI32 $ fromIntegral $ fromEnum (ae_type ae)))
]
-readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn
+readAppExn :: Protocol p => p -> IO AppExn
readAppExn pt = do
let typemap = Map.fromList [(1,("message",T_STRING)),(2,("type",T_I32))]
TStruct fields <- readVal pt $ T_STRUCT typemap
diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs
index ed779a27d..67a9175cb 100644
--- a/lib/hs/src/Thrift/Protocol.hs
+++ b/lib/hs/src/Thrift/Protocol.hs
@@ -22,12 +22,11 @@
module Thrift.Protocol
( Protocol(..)
+ , StatelessProtocol(..)
, ProtocolExn(..)
, ProtocolExnType(..)
, getTypeOf
, runParser
- , versionMask
- , version1
, bsToDouble
, bsToDoubleLE
) where
@@ -35,7 +34,6 @@ module Thrift.Protocol
import Control.Exception
import Data.Attoparsec.ByteString
import Data.Bits
-import Data.ByteString.Lazy (ByteString, toStrict)
import Data.ByteString.Unsafe
import Data.Functor ((<$>))
import Data.Int
@@ -44,37 +42,26 @@ import Data.Text.Lazy (Text)
import Data.Typeable (Typeable)
import Data.Word
import Foreign.Ptr (castPtr)
-import Foreign.Storable (Storable, peek, poke)
+import Foreign.Storable (peek, poke)
import System.IO.Unsafe
import qualified Data.ByteString as BS
import qualified Data.HashMap.Strict as Map
+import qualified Data.ByteString.Lazy as LBS
-import Thrift.Types
import Thrift.Transport
-
-versionMask :: Int32
-versionMask = fromIntegral (0xffff0000 :: Word32)
-
-version1 :: Int32
-version1 = fromIntegral (0x80010000 :: Word32)
+import Thrift.Types
class Protocol a where
- getTransport :: Transport t => a t -> t
-
- writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO ()
- writeMessageEnd :: Transport t => a t -> IO ()
- writeMessageEnd _ = return ()
-
- readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32)
- readMessageEnd :: Transport t => a t -> IO ()
- readMessageEnd _ = return ()
+ readByte :: a -> IO LBS.ByteString
+ readVal :: a -> ThriftType -> IO ThriftVal
+ readMessage :: a -> ((Text, MessageType, Int32) -> IO b) -> IO b
- serializeVal :: Transport t => a t -> ThriftVal -> ByteString
- deserializeVal :: Transport t => a t -> ThriftType -> ByteString -> ThriftVal
+ writeVal :: a -> ThriftVal -> IO ()
+ writeMessage :: a -> (Text, MessageType, Int32) -> IO () -> IO ()
- writeVal :: Transport t => a t -> ThriftVal -> IO ()
- writeVal p = tWrite (getTransport p) . serializeVal p
- readVal :: Transport t => a t -> ThriftType -> IO ThriftVal
+class Protocol a => StatelessProtocol a where
+ serializeVal :: a -> ThriftVal -> LBS.ByteString
+ deserializeVal :: a -> ThriftType -> LBS.ByteString -> ThriftVal
data ProtocolExnType
= PE_UNKNOWN
@@ -105,10 +92,10 @@ getTypeOf v = case v of
TBinary{} -> T_BINARY
TDouble{} -> T_DOUBLE
-runParser :: (Protocol p, Transport t, Show a) => p t -> Parser a -> IO a
+runParser :: (Protocol p, Show a) => p -> Parser a -> IO a
runParser prot p = refill >>= getResult . parse p
where
- refill = handle handleEOF $ toStrict <$> tReadAll (getTransport prot) 1
+ refill = handle handleEOF $ LBS.toStrict <$> readByte prot
getResult (Done _ a) = return a
getResult (Partial k) = refill >>= getResult . k
getResult f = throw $ ProtocolExn PE_INVALID_DATA (show f)
diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs
index 2d35305dc..7b0acd9d4 100644
--- a/lib/hs/src/Thrift/Protocol/Binary.hs
+++ b/lib/hs/src/Thrift/Protocol/Binary.hs
@@ -25,6 +25,8 @@
module Thrift.Protocol.Binary
( module Thrift.Protocol
, BinaryProtocol(..)
+ , versionMask
+ , version1
) where
import Control.Exception ( throw )
@@ -35,6 +37,7 @@ import Data.Functor
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
+import Data.Word
import Thrift.Protocol
import Thrift.Transport
@@ -47,37 +50,55 @@ import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT
-data BinaryProtocol a = BinaryProtocol a
+versionMask :: Int32
+versionMask = fromIntegral (0xffff0000 :: Word32)
+
+version1 :: Int32
+version1 = fromIntegral (0x80010000 :: Word32)
+
+data BinaryProtocol a = Transport a => BinaryProtocol a
+
+getTransport :: Transport t => BinaryProtocol t -> t
+getTransport (BinaryProtocol t) = t
-- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
-- encode and decode data. Data.Binary assumes that the binary values it is
-- encoding to and decoding from are in BIG ENDIAN format, and converts the
-- endianness as necessary to match the local machine.
-instance Protocol BinaryProtocol where
- getTransport (BinaryProtocol t) = t
-
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
- buildBinaryValue (TString $ encodeUtf8 n) <>
- buildBinaryValue (TI32 s)
-
- readMessageBegin p = runParser p $ do
- TI32 ver <- parseBinaryValue T_I32
- if ver .&. versionMask /= version1
- then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
- else do
- TString s <- parseBinaryValue T_STRING
- TI32 sz <- parseBinaryValue T_I32
- return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+instance Transport t => Protocol (BinaryProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ -- flushTransport p = tFlush (getTransport p)
+ writeMessage p (n, t, s) f = do
+ tWrite (getTransport p) messageBegin
+ f
+ tFlush $ getTransport p
+ where
+ messageBegin = toLazyByteString $
+ buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
+ buildBinaryValue (TString $ encodeUtf8 n) <>
+ buildBinaryValue (TI32 s)
+
+ readMessage p = (readMessageBegin p >>=)
+ where
+ readMessageBegin p = runParser p $ do
+ TI32 ver <- parseBinaryValue T_I32
+ if ver .&. versionMask /= version1
+ then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
+ else do
+ TString s <- parseBinaryValue T_STRING
+ TI32 sz <- parseBinaryValue T_I32
+ return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
+ readVal p = runParser p . parseBinaryValue
+instance Transport t => StatelessProtocol (BinaryProtocol t) where
serializeVal _ = toLazyByteString . buildBinaryValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
Left s -> error s
Right val -> val
- readVal p = runParser p . parseBinaryValue
-
-- | Writing Functions
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP
diff --git a/lib/hs/src/Thrift/Protocol/Compact.hs b/lib/hs/src/Thrift/Protocol/Compact.hs
index 07113df21..f23970a82 100644
--- a/lib/hs/src/Thrift/Protocol/Compact.hs
+++ b/lib/hs/src/Thrift/Protocol/Compact.hs
@@ -25,10 +25,11 @@
module Thrift.Protocol.Compact
( module Thrift.Protocol
, CompactProtocol(..)
+ , parseVarint
+ , buildVarint
) where
import Control.Applicative
-import Control.Exception ( throw )
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Lazy as LP
@@ -40,7 +41,7 @@ import Data.Monoid
import Data.Word
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
-import Thrift.Protocol hiding (versionMask)
+import Thrift.Protocol
import Thrift.Transport
import Thrift.Types
@@ -64,38 +65,47 @@ typeBits = 0x07 -- 0000 0111
typeShiftAmount :: Int
typeShiftAmount = 5
+getTransport :: Transport t => CompactProtocol t -> t
+getTransport (CompactProtocol t) = t
-instance Protocol CompactProtocol where
- getTransport (CompactProtocol t) = t
+instance Transport t => Protocol (CompactProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ writeMessage p (n, t, s) f = do
+ tWrite (getTransport p) messageBegin
+ f
+ tFlush $ getTransport p
+ where
+ messageBegin = toLazyByteString $
+ B.word8 protocolID <>
+ B.word8 ((version .&. versionMask) .|.
+ (((fromIntegral $ fromEnum t) `shiftL`
+ typeShiftAmount) .&. typeMask)) <>
+ buildVarint (i32ToZigZag s) <>
+ buildCompactValue (TString $ encodeUtf8 n)
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- B.word8 protocolID <>
- B.word8 ((version .&. versionMask) .|.
- (((fromIntegral $ fromEnum t) `shiftL`
- typeShiftAmount) .&. typeMask)) <>
- buildVarint (i32ToZigZag s) <>
- buildCompactValue (TString $ encodeUtf8 n)
-
- readMessageBegin p = runParser p $ do
- pid <- fromIntegral <$> P.anyWord8
- when (pid /= protocolID) $ error "Bad Protocol ID"
- w <- fromIntegral <$> P.anyWord8
- let ver = w .&. versionMask
- when (ver /= version) $ error "Bad Protocol version"
- let typ = (w `shiftR` typeShiftAmount) .&. typeBits
- seqId <- parseVarint zigZagToI32
- TString name <- parseCompactValue T_STRING
- return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+ readMessage p f = readMessageBegin >>= f
+ where
+ readMessageBegin = runParser p $ do
+ pid <- fromIntegral <$> P.anyWord8
+ when (pid /= protocolID) $ error "Bad Protocol ID"
+ w <- fromIntegral <$> P.anyWord8
+ let ver = w .&. versionMask
+ when (ver /= version) $ error "Bad Protocol version"
+ let typ = (w `shiftR` typeShiftAmount) .&. typeBits
+ seqId <- parseVarint zigZagToI32
+ TString name <- parseCompactValue T_STRING
+ return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue
+ readVal p ty = runParser p $ parseCompactValue ty
+
+instance Transport t => StatelessProtocol (CompactProtocol t) where
serializeVal _ = toLazyByteString . buildCompactValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ parseCompactValue ty
-
-
-- | Writing Functions
buildCompactValue :: ThriftVal -> Builder
buildCompactValue (TStruct fields) = buildCompactStruct fields
@@ -283,7 +293,7 @@ typeOf v = case v of
TSet{} -> 0x0A
TMap{} -> 0x0B
TStruct{} -> 0x0C
-
+
typeFrom :: Word8 -> ThriftType
typeFrom w = case w of
0x01 -> T_BOOL
diff --git a/lib/hs/src/Thrift/Protocol/Header.hs b/lib/hs/src/Thrift/Protocol/Header.hs
new file mode 100644
index 000000000..5f42db45d
--- /dev/null
+++ b/lib/hs/src/Thrift/Protocol/Header.hs
@@ -0,0 +1,141 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+
+module Thrift.Protocol.Header
+ ( module Thrift.Protocol
+ , HeaderProtocol(..)
+ , getProtocolType
+ , setProtocolType
+ , getHeaders
+ , getWriteHeaders
+ , setHeader
+ , setHeaders
+ , createHeaderProtocol
+ , createHeaderProtocol1
+ ) where
+
+import Thrift.Protocol
+import Thrift.Protocol.Binary
+import Thrift.Protocol.JSON
+import Thrift.Protocol.Compact
+import Thrift.Transport
+import Thrift.Transport.Header
+import Data.IORef
+import qualified Data.Map as Map
+
+data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a)
+
+instance Protocol ProtocolWrap where
+ readByte (ProtocolWrap p) = readByte p
+ readVal (ProtocolWrap p) = readVal p
+ readMessage (ProtocolWrap p) = readMessage p
+ writeVal (ProtocolWrap p) = writeVal p
+ writeMessage (ProtocolWrap p) = writeMessage p
+
+data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol {
+ trans :: HeaderTransport i o,
+ wrappedProto :: IORef ProtocolWrap
+ }
+
+createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap
+createProtocolWrap typ t =
+ case typ of
+ TBinary -> ProtocolWrap $ BinaryProtocol t
+ TCompact -> ProtocolWrap $ CompactProtocol t
+ TJSON -> ProtocolWrap $ JSONProtocol t
+
+createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO(HeaderProtocol i o)
+createHeaderProtocol i o = do
+ t <- openHeaderTransport i o
+ pid <- readIORef $ protocolType t
+ proto <- newIORef $ createProtocolWrap pid t
+ return $ HeaderProtocol { trans = t, wrappedProto = proto }
+
+createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t)
+createHeaderProtocol1 t = createHeaderProtocol t t
+
+resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO ()
+resetProtocol p = do
+ pid <- readIORef $ protocolType $ trans p
+ writeIORef (wrappedProto p) $ createProtocolWrap pid $ trans p
+
+getWrapped = readIORef . wrappedProto
+
+setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o
+setTransport p t = p { trans = t }
+
+updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o)-> HeaderProtocol i o
+updateTransport p f = setTransport p (f $ trans p)
+
+type Headers = Map.Map String String
+
+-- TODO: we want to set headers without recreating client...
+setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o
+setHeader p k v = updateTransport p $ \t -> t { writeHeaders = Map.insert k v $ writeHeaders t }
+
+setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o
+setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h }
+
+-- TODO: make it public once we have first transform implementation for Haskell
+setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o
+setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs }
+
+setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o
+setTransform p tr = updateTransport p $ \t -> t { writeTransforms = tr:(writeTransforms t) }
+
+getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers
+getWriteHeaders = writeHeaders . trans
+
+getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)]
+getHeaders = readIORef . headers . trans
+
+getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType
+getProtocolType p = readIORef $ protocolType $ trans p
+
+setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO ()
+setProtocolType p typ = do
+ typ0 <- getProtocolType p
+ if typ == typ0
+ then return ()
+ else do
+ tSetProtocol (trans p) typ
+ resetProtocol p
+
+instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where
+ readByte p = tReadAll (trans p) 1
+
+ readVal p tp = do
+ proto <- getWrapped p
+ readVal proto tp
+
+ readMessage p f = do
+ tResetProtocol (trans p)
+ resetProtocol p
+ proto <- getWrapped p
+ readMessage proto f
+
+ writeVal p v = do
+ proto <- getWrapped p
+ writeVal proto v
+
+ writeMessage p x f = do
+ proto <- getWrapped p
+ writeMessage proto x f
+
diff --git a/lib/hs/src/Thrift/Protocol/JSON.hs b/lib/hs/src/Thrift/Protocol/JSON.hs
index 7f619e8cb..839eddc84 100644
--- a/lib/hs/src/Thrift/Protocol/JSON.hs
+++ b/lib/hs/src/Thrift/Protocol/JSON.hs
@@ -29,12 +29,12 @@ module Thrift.Protocol.JSON
) where
import Control.Applicative
+import Control.Exception (bracket)
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Char8 as PC
import Data.Attoparsec.ByteString.Lazy as LP
import Data.ByteString.Base64.Lazy as B64C
-import Data.ByteString.Base64 as B64
import Data.ByteString.Lazy.Builder as B
import Data.ByteString.Internal (c2w, w2c)
import Data.Functor
@@ -58,38 +58,48 @@ import qualified Data.Text.Lazy as LT
-- encoded as a JSON 'ByteString'
data JSONProtocol t = JSONProtocol t
-- ^ Construct a 'JSONProtocol' with a 'Transport'
+getTransport :: Transport t => JSONProtocol t -> t
+getTransport (JSONProtocol t) = t
-instance Protocol JSONProtocol where
- getTransport (JSONProtocol t) = t
+instance Transport t => Protocol (JSONProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
- writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $
- B.char8 '[' <> buildShowable (1 :: Int32) <>
- B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
- B.char8 ',' <> buildShowable (fromEnum ty) <>
- B.char8 ',' <> buildShowable sq <>
- B.char8 ','
- writeMessageEnd (JSONProtocol t) = tWrite t "]"
- readMessageBegin p = runParser p $ skipSpace *> do
- _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
- bs <- lexeme (PC.char8 ',') *> lexeme escapedString
- case decodeUtf8' bs of
- Left _ -> fail "readMessage: invalid text encoding"
- Right str -> do
- ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
- seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
- _ <- PC.char8 ','
- return (str, ty, seqNum)
- readMessageEnd p = void $ runParser p (PC.char8 ']')
+ writeMessage (JSONProtocol t) (s, ty, sq) = bracket readMessageBegin readMessageEnd . const
+ where
+ readMessageBegin = tWrite t $ toLazyByteString $
+ B.char8 '[' <> buildShowable (1 :: Int32) <>
+ B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
+ B.char8 ',' <> buildShowable (fromEnum ty) <>
+ B.char8 ',' <> buildShowable sq <>
+ B.char8 ','
+ readMessageEnd _ = do
+ tWrite t "]"
+ tFlush t
+ readMessage p = bracket readMessageBegin readMessageEnd
+ where
+ readMessageBegin = runParser p $ skipSpace *> do
+ _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
+ bs <- lexeme (PC.char8 ',') *> lexeme escapedString
+ case decodeUtf8' bs of
+ Left _ -> fail "readMessage: invalid text encoding"
+ Right str -> do
+ ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
+ seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
+ _ <- PC.char8 ','
+ return (str, ty, seqNum)
+ readMessageEnd _ = void $ runParser p (PC.char8 ']')
+
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue
+ readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
+
+instance Transport t => StatelessProtocol (JSONProtocol t) where
serializeVal _ = toLazyByteString . buildJSONValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
-
-
-- Writing Functions
buildJSONValue :: ThriftVal -> Builder
diff --git a/lib/hs/src/Thrift/Server.hs b/lib/hs/src/Thrift/Server.hs
index ed74ceba6..543f33850 100644
--- a/lib/hs/src/Thrift/Server.hs
+++ b/lib/hs/src/Thrift/Server.hs
@@ -38,10 +38,10 @@ import Thrift.Protocol.Binary
-- | A threaded sever that is capable of using any Transport or Protocol
-- instances.
-runThreadedServer :: (Transport t, Protocol i, Protocol o)
- => (Socket -> IO (i t, o t))
+runThreadedServer :: (Protocol i, Protocol o)
+ => (Socket -> IO (i, o))
-> h
- -> (h -> (i t, o t) -> IO Bool)
+ -> (h -> (i, o) -> IO Bool)
-> PortID
-> IO a
runThreadedServer accepter hand proc_ port = do
diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs
index b7d16e4fb..ff6295b67 100644
--- a/lib/hs/src/Thrift/Transport/Handle.hs
+++ b/lib/hs/src/Thrift/Transport/Handle.hs
@@ -44,7 +44,13 @@ import Data.Monoid
instance Transport Handle where
tIsOpen = hIsOpen
tClose = hClose
- tRead h n = LBS.hGet h n `Control.Exception.catch` handleEOF mempty
+ tRead h n = read `Control.Exception.catch` handleEOF mempty
+ where
+ read = do
+ hLookAhead h
+ LBS.hGetNonBlocking h n
+ tReadAll _ 0 = return mempty
+ tReadAll h n = LBS.hGet h n `Control.Exception.catch` throwTransportExn
tPeek h = (Just . c2w <$> hLookAhead h) `Control.Exception.catch` handleEOF Nothing
tWrite = LBS.hPut
tFlush = hFlush
@@ -61,8 +67,12 @@ instance HandleSource FilePath where
instance HandleSource (HostName, PortID) where
hOpen = uncurry connectTo
+throwTransportExn :: IOError -> IO a
+throwTransportExn e = if isEOFError e
+ then throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN
+ else throw $ TransportExn "Handle tReadAll: Could not read" TE_UNKNOWN
handleEOF :: a -> IOError -> IO a
handleEOF a e = if isEOFError e
then return a
- else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN
+ else throw $ TransportExn "Handle: Could not read" TE_UNKNOWN
diff --git a/lib/hs/src/Thrift/Transport/Header.hs b/lib/hs/src/Thrift/Transport/Header.hs
new file mode 100644
index 000000000..2dacad25f
--- /dev/null
+++ b/lib/hs/src/Thrift/Transport/Header.hs
@@ -0,0 +1,354 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+module Thrift.Transport.Header
+ ( module Thrift.Transport
+ , HeaderTransport(..)
+ , openHeaderTransport
+ , ProtocolType(..)
+ , TransformType(..)
+ , ClientType(..)
+ , tResetProtocol
+ , tSetProtocol
+ ) where
+
+import Thrift.Transport
+import Thrift.Protocol.Compact
+import Control.Applicative
+import Control.Exception ( throw )
+import Control.Monad
+import Data.Bits
+import Data.IORef
+import Data.Int
+import Data.Monoid
+import Data.Word
+
+import qualified Data.Attoparsec.ByteString as P
+import qualified Data.Binary as Binary
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as C
+import qualified Data.ByteString.Lazy as LBS
+import qualified Data.ByteString.Lazy.Builder as B
+import qualified Data.Map as Map
+
+data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq)
+data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq)
+
+infoIdKeyValue = 1
+
+type Headers = Map.Map String String
+
+data TransformType = ZlibTransform deriving (Enum, Eq)
+
+fromTransportType :: TransformType -> Int16
+fromTransportType ZlibTransform = 1
+
+toTransportType :: Int16 -> TransformType
+toTransportType 1 = ZlibTransform
+toTransportType _ = throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN
+
+data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport
+ { readBuffer :: IORef LBS.ByteString
+ , writeBuffer :: IORef B.Builder
+ , inTrans :: i
+ , outTrans :: o
+ , clientType :: IORef ClientType
+ , protocolType :: IORef ProtocolType
+ , headers :: IORef [(String, String)]
+ , writeHeaders :: Headers
+ , transforms :: IORef [TransformType]
+ , writeTransforms :: [TransformType]
+ }
+
+openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o)
+openHeaderTransport i o = do
+ pid <- newIORef TCompact
+ rBuf <- newIORef LBS.empty
+ wBuf <- newIORef mempty
+ cType <- newIORef HeaderClient
+ h <- newIORef []
+ trans <- newIORef []
+ return HeaderTransport
+ { readBuffer = rBuf
+ , writeBuffer = wBuf
+ , inTrans = i
+ , outTrans = o
+ , clientType = cType
+ , protocolType = pid
+ , headers = h
+ , writeHeaders = Map.empty
+ , transforms = trans
+ , writeTransforms = []
+ }
+
+isFramed t = (/= Unframed) <$> readIORef (clientType t)
+
+readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
+readFrame t = do
+ let input = inTrans t
+ let rBuf = readBuffer t
+ let cType = clientType t
+ lsz <- tRead input 4
+ let sz = LBS.toStrict lsz
+ case P.parseOnly P.endOfInput sz of
+ Right _ -> do return False
+ Left _ -> do
+ case parseBinaryMagic sz of
+ Right _ -> do
+ writeIORef rBuf $ lsz
+ writeIORef cType Unframed
+ writeIORef (protocolType t) TBinary
+ return True
+ Left _ -> do
+ case parseCompactMagic sz of
+ Right _ -> do
+ writeIORef rBuf $ lsz
+ writeIORef cType Unframed
+ writeIORef (protocolType t) TCompact
+ return True
+ Left _ -> do
+ let len = Binary.decode lsz :: Int32
+ lbuf <- tReadAll input $ fromIntegral len
+ let buf = LBS.toStrict lbuf
+ case parseBinaryMagic buf of
+ Right _ -> do
+ writeIORef cType Framed
+ writeIORef (protocolType t) TBinary
+ writeIORef rBuf lbuf
+ return True
+ Left _ -> do
+ case parseCompactMagic buf of
+ Right _ -> do
+ writeIORef cType Framed
+ writeIORef (protocolType t) TCompact
+ writeIORef rBuf lbuf
+ return True
+ Left _ -> do
+ case parseHeaderMagic buf of
+ Right flags -> do
+ let (flags, seqNum, header, body) = extractHeader buf
+ writeIORef cType HeaderClient
+ handleHeader t header
+ payload <- untransform t body
+ writeIORef rBuf $ LBS.fromStrict $ payload
+ return True
+ Left _ ->
+ throw $ TransportExn "HeaderTransport: unkonwn client type" TE_UNKNOWN
+
+parseBinaryMagic = P.parseOnly $ P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8
+parseCompactMagic = P.parseOnly $ P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01)
+parseHeaderMagic = P.parseOnly $ P.word8 0x0f *> P.word8 0xff *> (P.count 2 P.anyWord8)
+
+parseI32 :: P.Parser Int32
+parseI32 = Binary.decode . LBS.fromStrict <$> P.take 4
+parseI16 :: P.Parser Int16
+parseI16 = Binary.decode . LBS.fromStrict <$> P.take 2
+
+extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString)
+extractHeader bs =
+ case P.parse extractHeader_ bs of
+ P.Done remain (flags, seqNum, header) -> (flags, seqNum, header, remain)
+ _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
+ where
+ extractHeader_ = do
+ magic <- P.word8 0x0f *> P.word8 0xff
+ flags <- parseI16
+ seqNum <- parseI32
+ (headerSize :: Int) <- (* 4) . fromIntegral <$> parseI16
+ header <- P.take headerSize
+ return (flags, seqNum, header)
+
+handleHeader t header =
+ case P.parseOnly parseHeader header of
+ Right (pType, trans, info) -> do
+ writeIORef (protocolType t) pType
+ writeIORef (transforms t) trans
+ writeIORef (headers t) info
+ _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
+
+
+iw16 :: Int16 -> Word16
+iw16 = fromIntegral
+iw32 :: Int32 -> Word32
+iw32 = fromIntegral
+wi16 :: Word16 -> Int16
+wi16 = fromIntegral
+wi32 :: Word32 -> Int32
+wi32 = fromIntegral
+
+parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)])
+parseHeader = do
+ protocolType <- toProtocolType <$> parseVarint wi16
+ numTrans <- fromIntegral <$> parseVarint wi16
+ trans <- replicateM numTrans parseTransform
+ info <- parseInfo
+ return (protocolType, trans, info)
+
+toProtocolType :: Int16 -> ProtocolType
+toProtocolType 0 = TBinary
+toProtocolType 1 = TJSON
+toProtocolType 2 = TCompact
+
+fromProtocolType :: ProtocolType -> Int16
+fromProtocolType TBinary = 0
+fromProtocolType TJSON = 1
+fromProtocolType TCompact = 2
+
+parseTransform :: P.Parser TransformType
+parseTransform = toTransportType <$> parseVarint wi16
+
+parseInfo :: P.Parser [(String, String)]
+parseInfo = do
+ n <- P.eitherP P.endOfInput (parseVarint wi32)
+ case n of
+ Left _ -> return []
+ Right n0 ->
+ replicateM (fromIntegral n0) $ do
+ klen <- parseVarint wi16
+ k <- P.take $ fromIntegral klen
+ vlen <- parseVarint wi16
+ v <- P.take $ fromIntegral vlen
+ return (C.unpack k, C.unpack v)
+
+parseString :: P.Parser BS.ByteString
+parseString = parseVarint wi32 >>= (P.take . fromIntegral)
+
+buildHeader :: HeaderTransport i o -> IO B.Builder
+buildHeader t = do
+ pType <- readIORef $ protocolType t
+ let pId = buildVarint $ iw16 $ fromProtocolType pType
+ let headerContent = pId <> (buildTransforms t) <> (buildInfo t)
+ let len = fromIntegral $ LBS.length $ B.toLazyByteString headerContent
+ -- TODO: length limit check
+ let padding = mconcat $ replicate (mod len 4) $ B.word8 0
+ let codedLen = B.int16BE (fromIntegral $ (quot (len - 1) 4) + 1)
+ let flags = 0
+ let seqNum = 0
+ return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding
+
+buildTransforms :: HeaderTransport i o -> B.Builder
+-- TODO: check length limit
+buildTransforms t =
+ let trans = writeTransforms t in
+ (buildVarint $ iw16 $ fromIntegral $ length trans) <>
+ (mconcat $ map (buildVarint . iw16 . fromTransportType) trans)
+
+buildInfo :: HeaderTransport i o -> B.Builder
+buildInfo t =
+ let h = Map.assocs $ writeHeaders t in
+ -- TODO: check length limit
+ case length h of
+ 0 -> mempty
+ len -> (buildVarint $ iw16 $ fromIntegral $ len) <> (mconcat $ map buildInfoEntry h)
+ where
+ buildInfoEntry (k, v) = buildVarStr k <> buildVarStr v
+ -- TODO: check length limit
+ buildVarStr s = (buildVarint $ iw16 $ fromIntegral $ length s) <> B.string8 s
+
+tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
+tResetProtocol t = do
+ rBuf <- readIORef $ readBuffer t
+ writeIORef (clientType t) HeaderClient
+ readFrame t
+
+tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO ()
+tSetProtocol t = writeIORef (protocolType t)
+
+transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString
+transform t bs =
+ foldr applyTransform bs $ writeTransforms t
+ where
+ -- applyTransform bs ZlibTransform =
+ -- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
+ applyTransform bs _ =
+ throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
+
+untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString
+untransform t bs = do
+ trans <- readIORef $ transforms t
+ return $ foldl unapplyTransform bs trans
+ where
+ -- unapplyTransform bs ZlibTransform =
+ -- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
+ unapplyTransform bs _ =
+ throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
+
+instance (Transport i, Transport o) => Transport (HeaderTransport i o) where
+ tIsOpen t = do
+ tIsOpen (inTrans t)
+ tIsOpen (outTrans t)
+
+ tClose t = do
+ tClose(outTrans t)
+ tClose(inTrans t)
+
+ tRead t len = do
+ rBuf <- readIORef $ readBuffer t
+ if not $ LBS.null rBuf
+ then do
+ let (consumed, remain) = LBS.splitAt (fromIntegral len) rBuf
+ writeIORef (readBuffer t) remain
+ return consumed
+ else do
+ framed <- isFramed t
+ if not framed
+ then tRead (inTrans t) len
+ else do
+ ok <- readFrame t
+ if ok
+ then tRead t len
+ else return LBS.empty
+
+ tPeek t = do
+ rBuf <- readIORef (readBuffer t)
+ if not $ LBS.null rBuf
+ then return $ Just $ LBS.head rBuf
+ else do
+ framed <- isFramed t
+ if not framed
+ then tPeek (inTrans t)
+ else do
+ ok <- readFrame t
+ if ok
+ then tPeek t
+ else return Nothing
+
+ tWrite t buf = do
+ let wBuf = writeBuffer t
+ framed <- isFramed t
+ if framed
+ then modifyIORef wBuf (<> B.lazyByteString buf)
+ else
+ -- TODO: what should we do when switched to unframed in the middle ?
+ tWrite(outTrans t) buf
+
+ tFlush t = do
+ cType <- readIORef $ clientType t
+ case cType of
+ Unframed -> tFlush $ outTrans t
+ Framed -> flushBuffer t id mempty
+ HeaderClient -> buildHeader t >>= flushBuffer t (transform t)
+ where
+ flushBuffer t f header = do
+ wBuf <- readIORef $ writeBuffer t
+ writeIORef (writeBuffer t) mempty
+ let payload = B.toLazyByteString (header <> wBuf)
+ tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length payload :: Int32)
+ tWrite (outTrans t) $ f payload
+ tFlush (outTrans t)
diff --git a/lib/hs/thrift.cabal b/lib/hs/thrift.cabal
index 9754ab2ee..583067953 100644
--- a/lib/hs/thrift.cabal
+++ b/lib/hs/thrift.cabal
@@ -49,6 +49,7 @@ Library
Thrift,
Thrift.Arbitraries
Thrift.Protocol,
+ Thrift.Protocol.Header,
Thrift.Protocol.Binary,
Thrift.Protocol.Compact,
Thrift.Protocol.JSON,
@@ -57,6 +58,7 @@ Library
Thrift.Transport.Empty,
Thrift.Transport.Framed,
Thrift.Transport.Handle,
+ Thrift.Transport.Header,
Thrift.Transport.HttpClient,
Thrift.Transport.IOBuffer,
Thrift.Transport.Memory,
diff --git a/test/hs/TestClient.hs b/test/hs/TestClient.hs
index d1ebb3cd0..93fb591b3 100644
--- a/test/hs/TestClient.hs
+++ b/test/hs/TestClient.hs
@@ -46,6 +46,7 @@ import Thrift.Transport.HttpClient
import Thrift.Protocol
import Thrift.Protocol.Binary
import Thrift.Protocol.Compact
+import Thrift.Protocol.Header
import Thrift.Protocol.JSON
data Options = Options
@@ -85,12 +86,14 @@ getTransport t host port = do return (NoTransport $ "Unsupported transport: " ++
data ProtocolType = Binary
| Compact
| JSON
+ | Header
deriving (Show, Eq)
getProtocol :: String -> ProtocolType
getProtocol "binary" = Binary
getProtocol "compact" = Compact
getProtocol "json" = JSON
+getProtocol "header" = Header
getProtocol p = error $ "Unsupported Protocol: " ++ p
defaultOptions :: Options
@@ -104,7 +107,7 @@ defaultOptions = Options
, testLoops = 1
}
-runClient :: (Protocol p, Transport t) => p t -> IO ()
+runClient :: Protocol p => p -> IO ()
runClient p = do
let prot = (p,p)
putStrLn "Starting Tests"
@@ -266,6 +269,7 @@ main = do
Binary -> runClient $ BinaryProtocol t
Compact -> runClient $ CompactProtocol t
JSON -> runClient $ JSONProtocol t
+ Header -> createHeaderProtocol t t >>= runClient
runTest loops p t = do
let client = makeClient p t
replicateM_ loops client
diff --git a/test/hs/TestServer.hs b/test/hs/TestServer.hs
index 4a88649b8..b7731ab1c 100644
--- a/test/hs/TestServer.hs
+++ b/test/hs/TestServer.hs
@@ -48,6 +48,7 @@ import Thrift.Transport.Framed
import Thrift.Transport.Handle
import Thrift.Protocol.Binary
import Thrift.Protocol.Compact
+import Thrift.Protocol.Header
import Thrift.Protocol.JSON
data Options = Options
@@ -90,11 +91,13 @@ getTransport t = NoTransport $ "Unsupported transport: " ++ t
data ProtocolType = Binary
| Compact
| JSON
+ | Header
getProtocol :: String -> ProtocolType
getProtocol "binary" = Binary
getProtocol "compact" = Compact
getProtocol "json" = JSON
+getProtocol "header" = Header
getProtocol p = error $"Unsupported Protocol: " ++ p
defaultOptions :: Options
@@ -261,13 +264,19 @@ main = do
t <- f socket
return (p t, p t)
+ headerAcceptor f socket = do
+ t <- f socket
+ p <- createHeaderProtocol1 t
+ return (p, p)
+
doRunServer p f = do
runThreadedServer (acceptor p f) TestHandler ThriftTest.process . PortNumber . fromIntegral
runServer p f port = case p of
- Binary -> do doRunServer BinaryProtocol f port
- Compact -> do doRunServer CompactProtocol f port
- JSON -> do doRunServer JSONProtocol f port
+ Binary -> doRunServer BinaryProtocol f port
+ Compact -> doRunServer CompactProtocol f port
+ JSON -> doRunServer JSONProtocol f port
+ Header -> runThreadedServer (headerAcceptor f) TestHandler ThriftTest.process (PortNumber $ fromIntegral port)
parseFlags :: [String] -> Options -> Maybe Options
parseFlags (flag : flags) opts = do
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index c96198808..754535f12 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -229,6 +229,8 @@
"go-java_json_http-ip",
"go-java_json_http-ip-ssl",
"go-nodejs_json_framed-ip",
+ "hs-csharp_binary_framed-ip",
+ "hs-csharp_compact_framed-ip",
"hs-dart_binary_framed-ip",
"hs-dart_compact_framed-ip",
"hs-dart_json_framed-ip",
@@ -331,4 +333,4 @@
"rs-dart_compact_framed-ip",
"rs-dart_multi-binary_framed-ip",
"rs-dart_multic-compact_framed-ip"
-] \ No newline at end of file
+]
diff --git a/test/tests.json b/test/tests.json
index 35d0a6cc1..c4e07eefb 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -216,6 +216,7 @@
"ip"
],
"protocols": [
+ "header",
"compact",
"binary",
"json"