diff options
author | Roger Meier <roger@apache.org> | 2014-04-05 00:45:42 +0200 |
---|---|---|
committer | Roger Meier <roger@apache.org> | 2014-04-05 00:50:35 +0200 |
commit | 6cf0ffcec969e4a983171a5f411506b2ed0fd2c1 (patch) | |
tree | f618a140d60a6d99af32225e260b7b5cb28b6cd1 /lib/lua | |
parent | bdbf428365144dc8586276d42c071b44c389e4ff (diff) | |
download | thrift-6cf0ffcec969e4a983171a5f411506b2ed0fd2c1.tar.gz |
THRIFT-1681: Add Lua Support Patch: Dave Watson
Github Pull Request: This closes #92
Diffstat (limited to 'lib/lua')
-rw-r--r-- | lib/lua/Makefile.am | 58 | ||||
-rw-r--r-- | lib/lua/TBinaryProtocol.lua | 264 | ||||
-rw-r--r-- | lib/lua/TBufferedTransport.lua | 91 | ||||
-rw-r--r-- | lib/lua/TFramedTransport.lua | 119 | ||||
-rw-r--r-- | lib/lua/TMemoryBuffer.lua | 91 | ||||
-rw-r--r-- | lib/lua/TProtocol.lua | 162 | ||||
-rw-r--r-- | lib/lua/TServer.lua | 139 | ||||
-rw-r--r-- | lib/lua/TSocket.lua | 132 | ||||
-rw-r--r-- | lib/lua/TTransport.lua | 93 | ||||
-rw-r--r-- | lib/lua/Thrift.lua | 273 | ||||
-rw-r--r-- | lib/lua/src/longnumberutils.c | 47 | ||||
-rw-r--r-- | lib/lua/src/luabitwise.c | 83 | ||||
-rw-r--r-- | lib/lua/src/luabpack.c | 162 | ||||
-rw-r--r-- | lib/lua/src/lualongnumber.c | 228 | ||||
-rw-r--r-- | lib/lua/src/luasocket.c | 386 | ||||
-rw-r--r-- | lib/lua/src/socket.h | 78 | ||||
-rw-r--r-- | lib/lua/src/usocket.c | 362 |
17 files changed, 2768 insertions, 0 deletions
diff --git a/lib/lua/Makefile.am b/lib/lua/Makefile.am new file mode 100644 index 000000000..1c4296795 --- /dev/null +++ b/lib/lua/Makefile.am @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = . + +lib_LTLIBRARIES = libluasocket.la \ + libluabpack.la \ + libluabitwise.la \ + liblualongnumber.la + +libluasocket_la_SOURCES = src/luasocket.c \ + src/usocket.c + +libluasocket_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluasocket_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +libluabpack_la_SOURCES = src/luabpack.c + +libluabpack_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluabpack_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm +libluabpack_la_LIBADD = liblualongnumber.la + +libluabitwise_la_SOURCES = src/luabitwise.c + +libluabitwise_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluabitwise_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +liblualongnumber_la_SOURCES = src/lualongnumber.c \ + src/longnumberutils.c + +liblualongnumber_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +liblualongnumber_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +EXTRA_DIST = TBinaryProtocol.lua \ + TBufferedTransport.lua \ + TFramedTransport.lua \ + Thrift.lua \ + TMemoryBuffer.lua \ + TProtocol.lua \ + TServer.lua \ + TSocket.lua \ + TTransport.lua
\ No newline at end of file diff --git a/lib/lua/TBinaryProtocol.lua b/lib/lua/TBinaryProtocol.lua new file mode 100644 index 000000000..df13d61e0 --- /dev/null +++ b/lib/lua/TBinaryProtocol.lua @@ -0,0 +1,264 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TProtocol' +require 'libluabpack' +require 'libluabitwise' + +TBinaryProtocol = __TObject.new(TProtocolBase, { + __type = 'TBinaryProtocol', + VERSION_MASK = -65536, -- 0xffff0000 + VERSION_1 = -2147418112, -- 0x80010000 + TYPE_MASK = 0x000000ff, + strictRead = false, + strictWrite = true +}) + +function TBinaryProtocol:writeMessageBegin(name, ttype, seqid) + if self.stirctWrite then + self:writeI32(libluabitwise.bor(TBinaryProtocol.VERSION_1, ttype)) + self:writeString(name) + self:writeI32(seqid) + else + self:writeString(name) + self:writeByte(ttype) + self:writeI32(seqid) + end +end + +function TBinaryProtocol:writeMessageEnd() +end + +function TBinaryProtocol:writeStructBegin(name) +end + +function TBinaryProtocol:writeStructEnd() +end + +function TBinaryProtocol:writeFieldBegin(name, ttype, id) + self:writeByte(ttype) + self:writeI16(id) +end + +function TBinaryProtocol:writeFieldEnd() +end + +function TBinaryProtocol:writeFieldStop() + self:writeByte(TType.STOP); +end + +function TBinaryProtocol:writeMapBegin(ktype, vtype, size) + self:writeByte(ktype) + self:writeByte(vtype) + self:writeI32(size) +end + +function TBinaryProtocol:writeMapEnd() +end + +function TBinaryProtocol:writeListBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeListEnd() +end + +function TBinaryProtocol:writeSetBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeSetEnd() +end + +function TBinaryProtocol:writeBool(bool) + if bool then + self:writeByte(1) + else + self:writeByte(0) + end +end + +function TBinaryProtocol:writeByte(byte) + local buff = libluabpack.bpack('c', byte) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI16(i16) + local buff = libluabpack.bpack('s', i16) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI32(i32) + local buff = libluabpack.bpack('i', i32) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI64(i64) + local buff = libluabpack.bpack('l', i64) + self.trans:write(buff) +end + +function TBinaryProtocol:writeDouble(dub) + local buff = libluabpack.bpack('d', dub) + self.trans:write(buff) +end + +function TBinaryProtocol:writeString(str) + -- Should be utf-8 + self:writeI32(string.len(str)) + self.trans:write(str) +end + +function TBinaryProtocol:readMessageBegin() + local sz, ttype, name, seqid = self:readI32() + if sz < 0 then + local version = libluabitwise.band(sz, TBinaryProtocol.VERSION_MASK) + if version ~= TBinaryProtocol.VERSION_1 then + terror(TProtocolException:new{ + message = 'Bad version in readMessageBegin: ' .. sz + }) + end + ttype = libluabitwise.band(sz, TBinaryProtocol.TYPE_MASK) + name = self:readString() + seqid = self:readI32() + else + if self.strictRead then + terror(TProtocolException:new{message = 'No protocol version header'}) + end + name = self.trans:readAll(sz) + ttype = self:readByte() + seqid = self:readI32() + end + return name, ttype, seqid +end + +function TBinaryProtocol:readMessageEnd() +end + +function TBinaryProtocol:readStructBegin() + return nil +end + +function TBinaryProtocol:readStructEnd() +end + +function TBinaryProtocol:readFieldBegin() + local ttype = self:readByte() + if ttype == TType.STOP then + return nil, ttype, 0 + end + local id = self:readI16() + return nil, ttype, id +end + +function TBinaryProtocol:readFieldEnd() +end + +function TBinaryProtocol:readMapBegin() + local ktype = self:readByte() + local vtype = self:readByte() + local size = self:readI32() + return ktype, vtype, size +end + +function TBinaryProtocol:readMapEnd() +end + +function TBinaryProtocol:readListBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readListEnd() +end + +function TBinaryProtocol:readSetBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readSetEnd() +end + +function TBinaryProtocol:readBool() + local byte = self:readByte() + if byte == 0 then + return false + end + return true +end + +function TBinaryProtocol:readByte() + local buff = self.trans:readAll(1) + local val = libluabpack.bunpack('c', buff) + return val +end + +function TBinaryProtocol:readI16() + local buff = self.trans:readAll(2) + local val = libluabpack.bunpack('s', buff) + return val +end + +function TBinaryProtocol:readI32() + local buff = self.trans:readAll(4) + local val = libluabpack.bunpack('i', buff) + return val +end + +function TBinaryProtocol:readI64() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('l', buff) + return val +end + +function TBinaryProtocol:readDouble() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('d', buff) + return val +end + +function TBinaryProtocol:readString() + local len = self:readI32() + local str = self.trans:readAll(len) + return str +end + +TBinaryProtocolFactory = TProtocolFactory:new{ + __type = 'TBinaryProtocolFactory', + strictRead = false +} + +function TBinaryProtocolFactory:getProtocol(trans) + -- TODO Enforce that this must be a transport class (ie not a bool) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBinaryProtocol:new{ + trans = trans, + strictRead = self.strictRead, + strictWrite = true + } +end diff --git a/lib/lua/TBufferedTransport.lua b/lib/lua/TBufferedTransport.lua new file mode 100644 index 000000000..2b0b94647 --- /dev/null +++ b/lib/lua/TBufferedTransport.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TBufferedTransport = TTransportBase:new{ + __type = 'TBufferedTransport', + rBufSize = 2048, + wBufSize = 2048, + wBuf = '', + rBuf = '' +} + +function TBufferedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase:new(obj) +end + +function TBufferedTransport:isOpen() + return self.trans:isOpen() +end + +function TBufferedTransport:open() + return self.trans:open() +end + +function TBufferedTransport:close() + return self.trans:close() +end + +function TBufferedTransport:read(len) + return self.trans:read(len) +end + +function TBufferedTransport:readAll(len) + return self.trans:readAll(len) +end + +function TBufferedTransport:write(buf) + self.wBuf = self.wBuf .. buf + if string.len(self.wBuf) >= self.wBufSize then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +function TBufferedTransport:flush() + if string.len(self.wBuf) > 0 then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +TBufferedTransportFactory = TTransportFactoryBase:new{ + __type = 'TBufferedTransportFactory' +} + +function TBufferedTransportFactory:getTransport(trans) + if not trans then + terror(TTransportException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBufferedTransport:new{ + trans = trans + } +end diff --git a/lib/lua/TFramedTransport.lua b/lib/lua/TFramedTransport.lua new file mode 100644 index 000000000..84ae3ecf2 --- /dev/null +++ b/lib/lua/TFramedTransport.lua @@ -0,0 +1,119 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluabpack' + +TFramedTransport = TTransportBase:new{ + __type = 'TFramedTransport', + doRead = true, + doWrite = true, + wBuf = '', + rBuf = '' +} + +function TFramedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase:new(obj) +end + +function TFramedTransport:isOpen() + return self.trans:isOpen() +end + +function TFramedTransport:open() + return self.trans:open() +end + +function TFramedTransport:close() + return self.trans:close() +end + +function TFramedTransport:read(len) + if string.len(self.rBuf) == 0 then + self:__readFrame() + end + + if self.doRead == false then + return self.trans:read(len) + end + + if len > string.len(self.rBuf) then + local val = self.rBuf + self.rBuf = '' + return val + end + + local val = string.sub(self.rBuf, 0, len) + self.rBuf = string.sub(self.rBuf, len) + return val +end + +function TFramedTransport:__readFrame() + local buf = self.trans:readAll(4) + local frame_len = libluabpack.bunpack('i', buf) + self.rBuf = self.trans:readAll(frame_len) +end + +function TFramedTransport:readAll(len) + return self.trans:readAll(len) +end + +function TFramedTransport:write(buf, len) + if self.doWrite == false then + return self.trans:write(buf, len) + end + + if len and len < string.len(buf) then + buf = string.sub(buf, 0, len) + end + self.wBuf = self.wBuf + buf +end + +function TFramedTransport:flush() + if self.doWrite == false then + return self.trans:flush() + end + + -- If the write fails we still want wBuf to be clear + local tmp = self.wBuf + self.wBuf = '' + self.trans:write(tmp) + self.trans:flush() +end + +TFramedTransportFactory = TTransportFactoryBase:new{ + __type = 'TFramedTransportFactory' +} +function TFramedTransportFactory:getTransport(trans) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TFramedTransport:new{trans = trans} +end diff --git a/lib/lua/TMemoryBuffer.lua b/lib/lua/TMemoryBuffer.lua new file mode 100644 index 000000000..3d4368674 --- /dev/null +++ b/lib/lua/TMemoryBuffer.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TMemoryBuffer = TTransportBase:new{ + __type = 'TMemoryBuffer', + buffer = '', + bufferSize = 1024, + wPos = 0, + rPos = 0 +} +function TMemoryBuffer:isOpen() + return 1 +end +function TMemoryBuffer:open() end +function TMemoryBuffer:close() end + +function TMemoryBuffer:peak() + return self.rPos < self.wPos +end + +function TMemoryBuffer:getBuffer() + return self.buffer +end + +function TMemoryBuffer:resetBuffer(buf) + if buf then + self.buffer = buf + self.bufferSize = string.len(buf) + else + self.buffer = '' + self.bufferSize = 1024 + end + self.wPos = string.len(buf) + self.rPos = 0 +end + +function TMemoryBuffer:available() + return self.wPos - self.rPos +end + +function TMemoryBuffer:read(len) + local avail = self:available() + if avail == 0 then + return '' + end + + if avail < len then + len = avail + end + + local val = string.sub(self.buffer, self.rPos, len) + self.rPos = self.rPos + len + return val +end + +function TMemoryBuffer:readAll(len) + local avail = self:available() + + if avail < len then + local msg = string.format('Attempt to readAll(%d) found only %d available', + len, avail) + terror(TTransportException:new{message = msg}) + end + -- read should block so we don't need a loop here + return self:read(len) +end + +function TMemoryBuffer:write(buf) + self.buffer = self.buffer + buf + self.wPos = self.wPos + buf +end + +function TMemoryBuffer:flush() end diff --git a/lib/lua/TProtocol.lua b/lib/lua/TProtocol.lua new file mode 100644 index 000000000..9eb94f595 --- /dev/null +++ b/lib/lua/TProtocol.lua @@ -0,0 +1,162 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TProtocolException = TException:new { + UNKNOWN = 0, + INVALID_DATA = 1, + NEGATIVE_SIZE = 2, + SIZE_LIMIT = 3, + BAD_VERSION = 4, + INVALID_PROTOCOL = 5, + MISSING_REQUIRED_FIELD = 6, + errorCode = 0, + __type = 'TProtocolException' +} +function TProtocolException:__errorCodeToString() + if self.errorCode == self.INVALID_DATA then + return 'Invalid data' + elseif self.errorCode == self.NEGATIVE_SIZE then + return 'Negative size' + elseif self.errorCode == self.SIZE_LIMIT then + return 'Size limit' + elseif self.errorCode == self.BAD_VERSION then + return 'Bad version' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.MISSING_REQUIRED_FIELD then + return 'Missing required field' + else + return 'Default (unknown)' + end +end + +TProtocolBase = __TObject:new{ + __type = 'TProtocolBase', + trans +} + +function TProtocolBase:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return __TObject.new(self, obj) +end + +function TProtocolBase:writeMessageBegin(name, ttype, seqid) end +function TProtocolBase:writeMessageEnd() end +function TProtocolBase:writeStructBegin(name) end +function TProtocolBase:writeStructEnd() end +function TProtocolBase:writeFieldBegin(name, ttype, id) end +function TProtocolBase:writeFieldEnd() end +function TProtocolBase:writeFieldStop() end +function TProtocolBase:writeMapBegin(ktype, vtype, size) end +function TProtocolBase:writeMapEnd() end +function TProtocolBase:writeListBegin(ttype, size) end +function TProtocolBase:writeListEnd() end +function TProtocolBase:writeSetBegin(ttype, size) end +function TProtocolBase:writeSetEnd() end +function TProtocolBase:writeBool(bool) end +function TProtocolBase:writeByte(byte) end +function TProtocolBase:writeI16(i16) end +function TProtocolBase:writeI32(i32) end +function TProtocolBase:writeI64(i64) end +function TProtocolBase:writeDouble(dub) end +function TProtocolBase:writeString(str) end +function TProtocolBase:readMessageBegin() end +function TProtocolBase:readMessageEnd() end +function TProtocolBase:readStructBegin() end +function TProtocolBase:readStructEnd() end +function TProtocolBase:readFieldBegin() end +function TProtocolBase:readFieldEnd() end +function TProtocolBase:readMapBegin() end +function TProtocolBase:readMapEnd() end +function TProtocolBase:readListBegin() end +function TProtocolBase:readListEnd() end +function TProtocolBase:readSetBegin() end +function TProtocolBase:readSetEnd() end +function TProtocolBase:readBool() end +function TProtocolBase:readByte() end +function TProtocolBase:readI16() end +function TProtocolBase:readI32() end +function TProtocolBase:readI64() end +function TProtocolBase:readDouble() end +function TProtocolBase:readString() end + +function TProtocolBase:skip(ttype) + if type == TType.STOP then + return + elseif ttype == TType.BOOL then + self:readBool() + elseif ttype == TType.BYTE then + self:readByte() + elseif ttype == TType.I16 then + self:readI16() + elseif ttype == TType.I32 then + self:readI32() + elseif ttype == TType.I64 then + self:readI64() + elseif ttype == TType.DOUBLE then + self:readDouble() + elseif ttype == TType.STRING then + self:readString() + elseif ttype == TType.STRUCT then + local name = self:readStructBegin() + while true do + local name, ttype, id = self:readFieldBegin() + if ttype == TType.STOP then + break + end + self:skip(ttype) + self:readFieldEnd() + end + self:readStructEnd() + elseif ttype == TType.MAP then + local kttype, vttype, size = self:readMapBegin() + for i = 1, size, 1 do + self:skip(kttype) + self:skip(vttype) + end + self:readMapEnd() + elseif ttype == TType.SET then + local ettype, size = self:readSetBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readSetEnd() + elseif ttype == TType.LIST then + local ettype, size = self:readListBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readListEnd() + end +end + +TProtocolFactory = __TObject:new{ + __type = 'TProtocolFactory', +} +function TProtocolFactory:getProtocol(trans) end diff --git a/lib/lua/TServer.lua b/lib/lua/TServer.lua new file mode 100644 index 000000000..d6b9cd076 --- /dev/null +++ b/lib/lua/TServer.lua @@ -0,0 +1,139 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' +require 'TFramedTransport' +require 'TBinaryProtocol' + +-- TServer +TServer = __TObject:new{ + __type = 'TServer' +} + +-- 2 possible constructors +-- 1. {processor, serverTransport} +-- 2. {processor, serverTransport, transportFactory, protocolFactory} +function TServer:new(args) + if ttype(args) ~= 'table' then + error('TServer must be initialized with a table') + end + if args.processor == nil then + terror('You must provide ' .. ttype(self) .. ' with a processor') + end + if args.serverTransport == nil then + terror('You must provide ' .. ttype(self) .. ' with a serverTransport') + end + + -- Create the object + local obj = __TObject.new(self, args) + + if obj.transportFactory then + obj.inputTransportFactory = obj.transportFactory + obj.outputTransportFactory = obj.transportFactory + obj.transportFactory = nil + else + obj.inputTransportFactory = TFramedTransportFactory:new{} + obj.outputTransportFactory = obj.inputTransportFactory + end + + if obj.protocolFactory then + obj.inputProtocolFactory = obj.protocolFactory + obj.outputProtocolFactory = obj.protocolFactory + obj.protocolFactory = nil + else + obj.inputProtocolFactory = TBinaryProtocolFactory:new{} + obj.outputProtocolFactory = obj.inputProtocolFactory + end + + -- Set the __server variable in the handler so we can stop the server + obj.processor.handler.__server = self + + return obj +end + +function TServer:setServerEventHandler(handler) + self.serverEventHandler = handler +end + +function TServer:_clientBegin(content, iprot, oprot) + if self.serverEventHandler and + type(self.serverEventHandler.clientBegin) == 'function' then + self.serverEventHandler:clientBegin(iprot, oprot) + end +end + +function TServer:_preServe() + if self.serverEventHandler and + type(self.serverEventHandler.preServe) == 'function' then + self.serverEventHandler:preServe(self.serverTransport:getSocketInfo()) + end +end + +function TServer:_handleException(err) + if string.find(err, 'TTransportException') == nil then + print(err) + end +end + +function TServer:serve() end +function TServer:handle(client) + local itrans, otrans, iprot, oprot, ret, err = + self.inputTransportFactory:getTransport(client), + self.outputTransportFactory:getTransport(client), + self.inputProtocolFactory:getProtocol(client), + self.outputProtocolFactory:getProtocol(client) + + self:_clientBegin(iprot, oprot) + while true do + ret, err = pcall(self.processor.process, self.processor, iprot, oprot) + if ret == false and err then + if not string.find(err, "TTransportException") then + self:_handleException(err) + end + break + end + end + itrans:close() + otrans:close() +end + +function TServer:close() + self.serverTransport:close() +end + +-- TSimpleServer +-- Single threaded server that handles one transport (connection) +TSimpleServer = __TObject:new(TServer, { + __type = 'TSimpleServer', + __stop = false +}) + +function TSimpleServer:serve() + self.serverTransport:listen() + self:_preServe() + while not self.__stop do + client = self.serverTransport:accept() + self:handle(client) + end + self:close() +end + +function TSimpleServer:stop() + self.__stop = true +end diff --git a/lib/lua/TSocket.lua b/lib/lua/TSocket.lua new file mode 100644 index 000000000..d71fc1f98 --- /dev/null +++ b/lib/lua/TSocket.lua @@ -0,0 +1,132 @@ +---- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluasocket' + +-- TSocketBase +TSocketBase = TTransportBase:new{ + __type = 'TSocketBase', + timeout = 1000, + host = 'localhost', + port = 9090, + handle +} + +function TSocketBase:close() + if self.handle then + self.handle:destroy() + self.handle = nil + end +end + +-- Returns a table with the fields host and port +function TSocketBase:getSocketInfo() + if self.handle then + return self.handle:getsockinfo() + end + terror(TTransportException:new{errorCode = TTransportException.NOT_OPEN}) +end + +function TSocketBase:setTimeout(timeout) + if timeout and ttype(timeout) == 'number' then + if self.handle then + self.handle:settimeout(timeout) + end + self.timeout = timeout + end +end + +-- TSocket +TSocket = TSocketBase:new{ + __type = 'TSocket', + host = 'localhost', + port = 9090 +} + +function TSocket:isOpen() + if self.handle then + return true + end + return false +end + +function TSocket:open() + if self.handle then + self:close() + end + + -- Create local handle + local sock, err = luasocket.create_and_connect( + self.host, self.port, self.timeout) + if err == nil then + self.handle = sock + end + + if err then + terror(TTransportException:new{ + message = 'Could not connect to ' .. self.host .. ':' .. self.port + .. ' (' .. err .. ')' + }) + end +end + +function TSocket:read(len) + local buf = self.handle:receive(self.handle, len) + if not buf or string.len(buf) ~= len then + terror(TTransportException:new{errorCode = TTransportException.UNKNOWN}) + end + return buf +end + +function TSocket:write(buf) + self.handle:send(self.handle, buf) +end + +function TSocket:flush() +end + +-- TServerSocket +TServerSocket = TSocketBase:new{ + __type = 'TServerSocket', + host = 'localhost', + port = 9090 +} + +function TServerSocket:listen() + if self.handle then + self:close() + end + + local sock, err = luasocket.create(self.host, self.port) + if not err then + self.handle = sock + else + terror(err) + end + self.handle:settimeout(self.timeout) + self.handle:listen() +end + +function TServerSocket:accept() + local client, err = self.handle:accept() + if err then + terror(err) + end + return TSocket:new({handle = client}) +end diff --git a/lib/lua/TTransport.lua b/lib/lua/TTransport.lua new file mode 100644 index 000000000..01c7e5979 --- /dev/null +++ b/lib/lua/TTransport.lua @@ -0,0 +1,93 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TTransportException = TException:new { + UNKNOWN = 0, + NOT_OPEN = 1, + ALREADY_OPEN = 2, + TIMED_OUT = 3, + END_OF_FILE = 4, + INVALID_FRAME_SIZE = 5, + INVALID_TRANSFORM = 6, + INVALID_CLIENT_TYPE = 7, + errorCode = 0, + __type = 'TTransportException' +} + +function TTransportException:__errorCodeToString() + if self.errorCode == self.NOT_OPEN then + return 'Transport not open' + elseif self.errorCode == self.ALREADY_OPEN then + return 'Transport already open' + elseif self.errorCode == self.TIMED_OUT then + return 'Transport timed out' + elseif self.errorCode == self.END_OF_FILE then + return 'End of file' + elseif self.errorCode == self.INVALID_FRAME_SIZE then + return 'Invalid frame size' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_CLIENT_TYPE then + return 'Invalid client type' + else + return 'Default (unknown)' + end +end + +TTransportBase = __TObject:new{ + __type = 'TTransportBase' +} + +function TTransportBase:isOpen() end +function TTransportBase:open() end +function TTransportBase:close() end +function TTransportBase:read(len) end +function TTransportBase:readAll(len) + local buf, have, chunk = '', 0 + while have < len do + chunk = self:read(len - have) + have = have + string.len(chunk) + buf = buf .. chunk + + if string.len(chunk) == 0 then + terror(TTransportException:new{ + errorCode = TTransportException.END_OF_FILE + }) + end + end + return buf +end +function TTransportBase:write(buf) end +function TTransportBase:flush() end + +TServerTransportBase = __TObject:new{ + __type = 'TServerTransportBase' +} +function TServerTransportBase:listen() end +function TServerTransportBase:accept() end +function TServerTransportBase:close() end + +TTransportFactoryBase = __TObject:new{ + __type = 'TTransportFactoryBase' +} +function TTransportFactoryBase:getTransport(trans) + return trans +end diff --git a/lib/lua/Thrift.lua b/lib/lua/Thrift.lua new file mode 100644 index 000000000..6ff8ecbc1 --- /dev/null +++ b/lib/lua/Thrift.lua @@ -0,0 +1,273 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +---- namespace thrift +--thrift = {} +--setmetatable(thrift, {__index = _G}) --> perf hit for accessing global methods +--setfenv(1, thrift) + +package.cpath = package.cpath .. ';bin/?.so' -- TODO FIX +function ttype(obj) + if type(obj) == 'table' and + obj.__type and + type(obj.__type) == 'string' then + return obj.__type + end + return type(obj) +end + +function terror(e) + if e and e.__tostring then + error(e:__tostring()) + return + end + error(e) +end + +version = 1.0 + +TType = { + STOP = 0, + VOID = 1, + BOOL = 2, + BYTE = 3, + I08 = 3, + DOUBLE = 4, + I16 = 6, + I32 = 8, + I64 = 10, + STRING = 11, + UTF7 = 11, + STRUCT = 12, + MAP = 13, + SET = 14, + LIST = 15, + UTF8 = 16, + UTF16 = 17 +} + +TMessageType = { + CALL = 1, + REPLY = 2, + EXCEPTION = 3, + ONEWAY = 4 +} + +-- Recursive __index function to achive inheritance +function __tobj_index(self, key) + local v = rawget(self, key) + if v ~= nil then + return v + end + + local p = rawget(self, '__parent') + if p then + return __tobj_index(p, key) + end + + return nil +end + +-- Basic Thrift-Lua Object +__TObject = { + __type = '__TObject', + __mt = { + __index = __tobj_index + } +} +function __TObject:new(init_obj) + local obj = {} + if ttype(obj) == 'table' then + obj = init_obj + end + + -- Use the __parent key and the __index function to achieve inheritance + obj.__parent = self + setmetatable(obj, __TObject.__mt) + return obj +end + +-- Return a string representation of any lua variable +function thrift_print_r(t) + local ret = '' + local ltype = type(t) + if (ltype == 'table') then + ret = ret .. '{ ' + for key,value in pairs(t) do + ret = ret .. tostring(key) .. '=' .. thrift_print_r(value) .. ' ' + end + ret = ret .. '}' + elseif ltype == 'string' then + ret = ret .. "'" .. tostring(t) .. "'" + else + ret = ret .. tostring(t) + end + return ret +end + +-- Basic Exception +TException = __TObject:new{ + message, + errorCode, + __type = 'TException' +} +function TException:__tostring() + if self.message then + return string.format('%s: %s', self.__type, self.message) + else + local message + if self.errorCode and self.__errorCodeToString then + message = string.format('%d: %s', self.errorCode, self:__errorCodeToString()) + else + message = thrift_print_r(self) + end + return string.format('%s:%s', self.__type, message) + end +end + +TApplicationException = TException:new{ + UNKNOWN = 0, + UNKNOWN_METHOD = 1, + INVALID_MESSAGE_TYPE = 2, + WRONG_METHOD_NAME = 3, + BAD_SEQUENCE_ID = 4, + MISSING_RESULT = 5, + INTERNAL_ERROR = 6, + PROTOCOL_ERROR = 7, + INVALID_TRANSFORM = 8, + INVALID_PROTOCOL = 9, + UNSUPPORTED_CLIENT_TYPE = 10, + errorCode = 0, + __type = 'TApplicationException' +} + +function TApplicationException:__errorCodeToString() + if self.errorCode == self.UNKNOWN_METHOD then + return 'Unknown method' + elseif self.errorCode == self.INVALID_MESSAGE_TYPE then + return 'Invalid message type' + elseif self.errorCode == self.WRONG_METHOD_NAME then + return 'Wrong method name' + elseif self.errorCode == self.BAD_SEQUENCE_ID then + return 'Bad sequence ID' + elseif self.errorCode == self.MISSING_RESULT then + return 'Missing result' + elseif self.errorCode == self.INTERNAL_ERROR then + return 'Internal error' + elseif self.errorCode == self.PROTOCOL_ERROR then + return 'Protocol error' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.UNSUPPORTED_CLIENT_TYPE then + return 'Unsupported client type' + else + return 'Default (unknown)' + end +end + +function TException:read(iprot) + iprot:readStructBegin() + while true do + local fname, ftype, fid = iprot:readFieldBegin() + if ftype == TType.STOP then + break + elseif fid == 1 then + if ftype == TType.STRING then + self.message = iprot:readString() + else + iprot:skip(ftype) + end + elseif fid == 2 then + if ftype == TType.I32 then + self.errorCode = iprot:readI32() + else + iprot:skip(ftype) + end + else + iprot:skip(ftype) + end + iprot:readFieldEnd() + end + iprot:readStructEnd() +end + +function TException:write(oprot) + oprot:writeStructBegin('TApplicationException') + if self.message then + oprot:writeFieldBegin('message', TType.STRING, 1) + oprot:writeString(self.message) + oprot:writeFieldEnd() + end + if self.errorCode then + oprot:writeFieldBegin('type', TType.I32, 2) + oprot:writeI32(self.errorCode) + oprot:writeFieldEnd() + end + oprot:writeFieldStop() + oprot:writeStructEnd() +end + +-- Basic Client (used in generated lua code) +__TClient = __TObject:new{ + __type = '__TClient', + _seqid = 0 +} +function __TClient:new(obj) + if ttype(obj) ~= 'table' then + error('TClient must be initialized with a table') + end + + -- Set iprot & oprot + if obj.protocol then + obj.iprot = obj.protocol + obj.oprot = obj.protocol + obj.protocol = nil + elseif not obj.iprot then + error('You must provide ' .. ttype(self) .. ' with an iprot') + end + if not obj.oprot then + obj.oprot = obj.iprot + end + + return __TObject.new(self, obj) +end + +function __TClient:close() + self.iprot.trans:close() + self.oprot.trans:close() +end + +-- Basic Processor (used in generated lua code) +__TProcessor = __TObject:new{ + __type = '__TProcessor' +} +function __TProcessor:new(obj) + if ttype(obj) ~= 'table' then + error('TProcessor must be initialized with a table') + end + + -- Ensure a handler is provided + if not obj.handler then + error('You must provide ' .. ttype(self) .. ' with a handler') + end + + return __TObject.new(self, obj) +end diff --git a/lib/lua/src/longnumberutils.c b/lib/lua/src/longnumberutils.c new file mode 100644 index 000000000..fbc678900 --- /dev/null +++ b/lib/lua/src/longnumberutils.c @@ -0,0 +1,47 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <stdlib.h> +#include <inttypes.h> + +const char * LONG_NUM_TYPE = "__thrift_longnumber"; +int64_t lualongnumber_checklong(lua_State *L, int index) { + switch (lua_type(L, index)) { + case LUA_TNUMBER: + return (int64_t)lua_tonumber(L, index); + case LUA_TSTRING: + return atoll(lua_tostring(L, index)); + default: + return *((int64_t *)luaL_checkudata(L, index, LONG_NUM_TYPE)); + } +} + +// Creates a new longnumber and pushes it onto the statck +int64_t * lualongnumber_pushlong(lua_State *L, int64_t *val) { + int64_t *data = (int64_t *)lua_newuserdata(L, sizeof(int64_t)); // longnum + luaL_getmetatable(L, LONG_NUM_TYPE); // longnum, mt + lua_setmetatable(L, -2); // longnum + if (val) { + *data = *val; + } + return data; +} + diff --git a/lib/lua/src/luabitwise.c b/lib/lua/src/luabitwise.c new file mode 100644 index 000000000..2e07e1724 --- /dev/null +++ b/lib/lua/src/luabitwise.c @@ -0,0 +1,83 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> + +static int l_not(lua_State *L) { + int a = luaL_checkinteger(L, 1); + a = ~a; + lua_pushnumber(L, a); + return 1; +} + +static int l_xor(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a ^= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_and(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a &= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_or(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a |= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftr(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a >> b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftl(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a << b; + lua_pushnumber(L, a); + return 1; +} + +static const struct luaL_Reg funcs[] = { + {"band", l_and}, + {"bor", l_or}, + {"bxor", l_xor}, + {"bnot", l_not}, + {"shiftl", l_shiftl}, + {"shiftr", l_shiftr}, + {NULL, NULL} +}; + +int luaopen_libluabitwise(lua_State *L) { + luaL_register(L, "libluabitwise", funcs); + return 1; +} diff --git a/lib/lua/src/luabpack.c b/lib/lua/src/luabpack.c new file mode 100644 index 000000000..c936428cd --- /dev/null +++ b/lib/lua/src/luabpack.c @@ -0,0 +1,162 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <string.h> +#include <inttypes.h> +#include <netinet/in.h> + +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +// host order to network order (64-bit) +static int64_t T_htonll(uint64_t data) { + uint32_t d1 = htonl((uint32_t)data); + uint32_t d2 = htonl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +// network order to host order (64-bit) +static int64_t T_ntohll(uint64_t data) { + uint32_t d1 = ntohl((uint32_t)data); + uint32_t d2 = ntohl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +/** + * bpack(type, data) + * c - Signed Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + switch (code[0]) { + case 'c': { + int8_t data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 's': { + int16_t data = luaL_checknumber(L, 2); + data = (int16_t)htons(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'i': { + int32_t data = luaL_checkinteger(L, 2); + data = (int32_t)htonl(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'l': { + int64_t data = lualongnumber_checklong(L, 2); + data = (int64_t)T_htonll(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'd': { + double data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + + luaL_pushresult(&buf); + return 1; +} + +/** + * bunpack(type, data) + * c - Signed Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bunpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + const char *data = luaL_checkstring(L, 2); + size_t len = lua_rawlen(L, 2); + + switch (code[0]) { + case 'c': { + int8_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + case 's': { + int16_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int16_t)ntohs(val); + lua_pushnumber(L, val); + break; + } + case 'i': { + int32_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int32_t)ntohl(val); + lua_pushnumber(L, val); + break; + } + case 'l': { + int64_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int64_t)T_ntohll(val); + lualongnumber_pushlong(L, &val); + break; + } + case 'd': { + double val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + return 1; +} + +static const struct luaL_Reg lua_bpack[] = { + {"bpack", l_bpack}, + {"bunpack", l_bunpack}, + {NULL, NULL} +}; + +int luaopen_libluabpack(lua_State *L) { + luaL_register(L, "libluabpack", lua_bpack); + return 1; +} diff --git a/lib/lua/src/lualongnumber.c b/lib/lua/src/lualongnumber.c new file mode 100644 index 000000000..9001e4a90 --- /dev/null +++ b/lib/lua/src/lualongnumber.c @@ -0,0 +1,228 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <stdlib.h> +#include <math.h> +#include <inttypes.h> +#include <string.h> + +extern const char * LONG_NUM_TYPE; +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +//////////////////////////////////////////////////////////////////////////////// + +static void l_serialize(char *buf, int len, int64_t val) { + snprintf(buf, len, "%"PRId64, val); +} + +static int64_t l_deserialize(const char *buf) { + int64_t data; + int rv; + // Support hex prefixed with '0x' + if (strstr(buf, "0x") == buf) { + rv = sscanf(buf, "%"PRIx64, &data); + } else { + rv = sscanf(buf, "%"PRId64, &data); + } + if (rv == 1) { + return data; + } + return 0; // Failed +} + +//////////////////////////////////////////////////////////////////////////////// + +static int l_new(lua_State *L) { + int64_t val; + const char *str = NULL; + if (lua_type(L, 1) == LUA_TSTRING) { + str = lua_tostring(L, 1); + val = l_deserialize(str); + } else if (lua_type(L, 1) == LUA_TNUMBER) { + val = (int64_t)lua_tonumber(L, 1); + str = (const char *)1; + } + lualongnumber_pushlong(L, (str ? &val : NULL)); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +// a + b +static int l_add(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a + b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a / b +static int l_div(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a / b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a == b (both a and b are lualongnumber's) +static int l_eq(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a == b ? 1 : 0)); + return 1; +} + +// garbage collection +static int l_gc(lua_State *L) { + lua_pushnil(L); + lua_setmetatable(L, 1); + return 0; +} + +// a < b +static int l_lt(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a < b ? 1 : 0)); + return 1; +} + +// a <= b +static int l_le(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a <= b ? 1 : 0)); + return 1; +} + +// a % b +static int l_mod(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a % b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a * b +static int l_mul(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a * b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a ^ b +static int l_pow(lua_State *L) { + long double a, b; + int64_t c; + a = (long double)lualongnumber_checklong(L, 1); + b = (long double)lualongnumber_checklong(L, 2); + c = (int64_t)pow(a, b); + lualongnumber_pushlong(L, &c); + return 1; +} + +// a - b +static int l_sub(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a - b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// tostring() +static int l_tostring(lua_State *L) { + int64_t a; + char str[256]; + l_serialize(str, 256, lualongnumber_checklong(L, 1)); + lua_pushstring(L, str); + return 1; +} + +// -a +static int l_unm(lua_State *L) { + int64_t a, c; + a = lualongnumber_checklong(L, 1); + c = -a; + lualongnumber_pushlong(L, &c); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +static const luaL_Reg methods[] = { + {"__add", l_add}, + {"__div", l_div}, + {"__eq", l_eq}, + {"__gc", l_gc}, + {"__lt", l_lt}, + {"__le", l_le}, + {"__mod", l_mod}, + {"__mul", l_mul}, + {"__pow", l_pow}, + {"__sub", l_sub}, + {"__tostring", l_tostring}, + {"__unm", l_unm}, + {NULL, NULL}, +}; + +static const luaL_Reg funcs[] = { + {"new", l_new}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // No need for a __index table since everything is __* + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "name" + lua_pushcfunction(L, methods->func); // mt, "name", func + lua_rawset(L, -3); // mt + } + lua_pop(L, 1); +} + +LUALIB_API int luaopen_liblualongnumber(lua_State *L) { + luaL_newmetatable(L, LONG_NUM_TYPE); + lua_pop(L, 1); + set_methods(L, LONG_NUM_TYPE, methods); + + luaL_register(L, "liblualongnumber", funcs); + return 1; +} diff --git a/lib/lua/src/luasocket.c b/lib/lua/src/luasocket.c new file mode 100644 index 000000000..c8a678ff5 --- /dev/null +++ b/lib/lua/src/luasocket.c @@ -0,0 +1,386 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> + +#include <unistd.h> +#include "string.h" +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// + +static const char *SOCKET_ANY = "__thrift_socket_any"; +static const char *SOCKET_CONN = "__thrift_socket_connected"; + +static const char *SOCKET_GENERIC = "__thrift_socket_generic"; +static const char *SOCKET_CLIENT = "__thrift_socket_client"; +static const char *SOCKET_SERVER = "__thrift_socket_server"; + +static const char *DEFAULT_HOST = "localhost"; + +typedef struct __t_tcp { + t_socket sock; + int timeout; // Milliseconds +} t_tcp; +typedef t_tcp *p_tcp; + +//////////////////////////////////////////////////////////////////////////////// +// Util + +static void throw_argerror(lua_State *L, int index, const char *expected) { + char msg[256]; + sprintf(msg, "%s expected, got %s", expected, luaL_typename(L, index)); + luaL_argerror(L, index, msg); +} + +static void *checkgroup(lua_State *L, int index, const char *groupname) { + if (!lua_getmetatable(L, index)) { + throw_argerror(L, index, groupname); + } + + lua_pushstring(L, groupname); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 2); + throw_argerror(L, index, groupname); + } else { + lua_pop(L, 2); + return lua_touserdata(L, index); + } + return NULL; // Not reachable +} + +static void *checktype(lua_State *L, int index, const char *typename) { + if (strcmp(typename, SOCKET_ANY) == 0 || + strcmp(typename, SOCKET_CONN) == 0) { + return checkgroup(L, index, typename); + } else { + return luaL_checkudata(L, index, typename); + } +} + +static void settype(lua_State *L, int index, const char *typename) { + luaL_getmetatable(L, typename); + lua_setmetatable(L, index); +} + +#define LUA_SUCCESS_RETURN(L) \ + lua_pushnumber(L, 1); \ + return 1 + +#define LUA_CHECK_RETURN(L, err) \ + if (err) { \ + lua_pushnil(L); \ + lua_pushstring(L, err); \ + return 2; \ + } \ + LUA_SUCCESS_RETURN(L) + +//////////////////////////////////////////////////////////////////////////////// + +static int l_socket_create(lua_State *L); +static int l_socket_destroy(lua_State *L); +static int l_socket_settimeout(lua_State *L); +static int l_socket_getsockinfo(lua_State *L); + +static int l_socket_accept(lua_State *L); +static int l_socket_listen(lua_State *L); + +static int l_socket_create_and_connect(lua_State *L); +static int l_socket_connect(lua_State *L); +static int l_socket_send(lua_State *L); +static int l_socket_receive(lua_State *L); + +//////////////////////////////////////////////////////////////////////////////// + +static const struct luaL_Reg methods_generic[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"listen", l_socket_listen}, + {"connect", l_socket_connect}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_server[] = { + {"destroy", l_socket_destroy}, + {"getsockinfo", l_socket_getsockinfo}, + {"accept", l_socket_accept}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_client[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg funcs_luasocket[] = { + {"create", l_socket_create}, + {"create_and_connect", l_socket_create_and_connect}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Check/enforce inheritance +static void add_to_group(lua_State *L, + const char *metatablename, + const char *groupname) { + luaL_getmetatable(L, metatablename); // mt + lua_pushstring(L, groupname); // mt, "name" + lua_pushboolean(L, 1); // mt, "name", true + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // Create the __index table + lua_pushstring(L, "__index"); // mt, "__index" + lua_newtable(L); // mt, "__index", t + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "__index", t, "name" + lua_pushcfunction(L, methods->func); // mt, "__index", t, "name", func + lua_rawset(L, -3); // mt, "__index", t + } + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +int luaopen_libluasocket(lua_State *L) { + luaL_newmetatable(L, SOCKET_GENERIC); + luaL_newmetatable(L, SOCKET_CLIENT); + luaL_newmetatable(L, SOCKET_SERVER); + lua_pop(L, 3); + add_to_group(L, SOCKET_GENERIC, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_ANY); + add_to_group(L, SOCKET_SERVER, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_CONN); + add_to_group(L, SOCKET_SERVER, SOCKET_CONN); + set_methods(L, SOCKET_GENERIC, methods_generic); + set_methods(L, SOCKET_CLIENT, methods_client); + set_methods(L, SOCKET_SERVER, methods_server); + + luaL_register(L, "luasocket", funcs_luasocket); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +// sock,err create(bind_host, bind_port) +// sock,err create(bind_host) -> any port +// sock,err create() -> any port on localhost +static int l_socket_create(lua_State *L) { + const char *err; + t_socket sock; + const char *addr = lua_tostring(L, 1); + if (!addr) { + addr = DEFAULT_HOST; + } + unsigned short port = lua_tonumber(L, 2); + err = tcp_create(&sock); + if (!err) { + err = tcp_bind(&sock, addr, port); // bind on create + if (err) { + tcp_destroy(&sock); + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_GENERIC); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = 0; + return 1; // Return userdata + } + } + LUA_CHECK_RETURN(L, err); +} + +// destroy() +static int l_socket_destroy(lua_State *L) { + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + const char *err = tcp_destroy(&tcp->sock); + LUA_CHECK_RETURN(L, err); +} + +// send(socket, data) +static int l_socket_send(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp tcp = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len; + const char *data = luaL_checklstring(L, 3, &len); + const char *err = + tcp_send(&tcp->sock, data, len, tcp->timeout); + LUA_CHECK_RETURN(L, err); +} + +#define LUA_READ_STEP 8192 +static int l_socket_receive(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp handle = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len = luaL_checknumber(L, 3); + char buf[LUA_READ_STEP]; + const char *err = NULL; + int received; + size_t got = 0, step = 0; + luaL_Buffer b; + + luaL_buffinit(L, &b); + do { + step = (LUA_READ_STEP < len - got ? LUA_READ_STEP : len - got); + err = tcp_raw_receive(&handle->sock, buf, step, self->timeout, &received); + if (err == NULL) { + luaL_addlstring(&b, buf, received); + got += received; + } + } while (err == NULL && got < len); + + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + luaL_pushresult(&b); + return 1; +} + +// settimeout(timeout) +static int l_socket_settimeout(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_ANY); + int timeout = luaL_checknumber(L, 2); + self->timeout = timeout; + LUA_SUCCESS_RETURN(L); +} + +// table getsockinfo() +static int l_socket_getsockinfo(lua_State *L) { + char buf[256]; + short port = 0; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + if (socket_get_info(&tcp->sock, &port, buf, 256) == SUCCESS) { + lua_newtable(L); // t + lua_pushstring(L, "host"); // t, "host" + lua_pushstring(L, buf); // t, "host", buf + lua_rawset(L, -3); // t + lua_pushstring(L, "port"); // t, "port" + lua_pushnumber(L, port); // t, "port", port + lua_rawset(L, -3); // t + return 1; + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +// accept() +static int l_socket_accept(lua_State *L) { + const char *err; + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_SERVER); + t_socket sock; + err = tcp_accept(&self->sock, &sock, self->timeout); + if (!err) { // Success + // Create a reference to the client + p_tcp client = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, 2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + client->sock = sock; + client->timeout = self->timeout; + return 1; + } + LUA_CHECK_RETURN(L, err); +} + +static int l_socket_listen(lua_State *L) { + const char* err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + int backlog = 10; + err = tcp_listen(&tcp->sock, backlog); + if (!err) { + // Set the current as a server + settype(L, 1, SOCKET_SERVER); // Now a server + } + LUA_CHECK_RETURN(L, err); +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +// create_and_connect(host, port, timeout) +extern double __gettime(); +static int l_socket_create_and_connect(lua_State *L) { + const char* err = NULL; + double end; + t_socket sock; + const char *host = luaL_checkstring(L, 1); + unsigned short port = luaL_checknumber(L, 2); + int timeout = luaL_checknumber(L, 3); + + // Create and connect loop for timeout milliseconds + end = __gettime() + timeout/1000; + do { + // Create the socket + err = tcp_create(&sock); + if (!err) { + // Bind to any port on localhost + err = tcp_bind(&sock, DEFAULT_HOST, 0); + if (err) { + tcp_destroy(&sock); + } else { + // Connect + err = tcp_connect(&sock, host, port, timeout); + if (err) { + tcp_destroy(&sock); + usleep(100000); // sleep for 100ms + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = timeout; + return 1; // Return userdata + } + } + } + } while (err && __gettime() < end); + + LUA_CHECK_RETURN(L, err); +} + +// connect(host, port) +static int l_socket_connect(lua_State *L) { + const char *err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + const char *host = luaL_checkstring(L, 2); + unsigned short port = luaL_checknumber(L, 3); + err = tcp_connect(&tcp->sock, host, port, tcp->timeout); + if (!err) { + settype(L, 1, SOCKET_CLIENT); // Now a client + } + LUA_CHECK_RETURN(L, err); +} diff --git a/lib/lua/src/socket.h b/lib/lua/src/socket.h new file mode 100644 index 000000000..8019ffed8 --- /dev/null +++ b/lib/lua/src/socket.h @@ -0,0 +1,78 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#ifndef LUA_THRIFT_SOCKET_H +#define LUA_THRIFT_SOCKET_H + +#include <sys/socket.h> + +#ifdef _WIN32 +// SOL +#else +typedef int t_socket; +typedef t_socket* p_socket; +#endif + +// Error Codes +enum { + SUCCESS = 0, + TIMEOUT = -1, + CLOSED = -2, +}; +typedef int T_ERRCODE; + +static const char * TIMEOUT_MSG = "Timeout"; +static const char * CLOSED_MSG = "Connection Closed"; + +typedef struct sockaddr t_sa; +typedef t_sa * p_sa; + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol); +T_ERRCODE socket_destroy(p_socket sock); +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len); +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len); +T_ERRCODE socket_send(p_socket sock, const char *data, size_t len, int timeout); +T_ERRCODE socket_recv(p_socket sock, char *data, size_t len, int timeout, + int *received); + +void socket_setblocking(p_socket sock); +void socket_setnonblocking(p_socket sock); + +T_ERRCODE socket_accept(p_socket sock, p_socket sibling, + p_sa addr, socklen_t *addr_len, int timeout); +T_ERRCODE socket_listen(p_socket sock, int backlog); + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout); + +const char * tcp_create(p_socket sock); +const char * tcp_destroy(p_socket sock); +const char * tcp_bind(p_socket sock, const char *host, unsigned short port); +const char * tcp_send(p_socket sock, const char *data, size_t w_len, + int timeout); +const char * tcp_receive(p_socket sock, char *data, size_t r_len, int timeout); +const char * tcp_raw_receive(p_socket sock, char * data, size_t r_len, + int timeout, int *received); + +const char * tcp_listen(p_socket sock, int backlog); +const char * tcp_accept(p_socket sock, p_socket client, int timeout); + +const char * tcp_connect(p_socket sock, const char *host, unsigned short port, + int timeout); + +#endif diff --git a/lib/lua/src/usocket.c b/lib/lua/src/usocket.c new file mode 100644 index 000000000..be696e06e --- /dev/null +++ b/lib/lua/src/usocket.c @@ -0,0 +1,362 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <sys/time.h> +#include <sys/types.h> +#include <arpa/inet.h> +#include <netdb.h> +#include <string.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> + +#include <stdio.h> // TODO REMOVE + +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// +// Private + +// Num seconds since Jan 1 1970 (UTC) +#ifdef _WIN32 +// SOL +#else + double __gettime() { + struct timeval v; + gettimeofday(&v, (struct timezone*) NULL); + return v.tv_sec + v.tv_usec/1.0e6; + } +#endif + +#define WAIT_MODE_R 1 +#define WAIT_MODE_W 2 +#define WAIT_MODE_C (WAIT_MODE_R|WAIT_MODE_W) +T_ERRCODE socket_wait(p_socket sock, int mode, int timeout) { + int ret = 0; + fd_set rfds, wfds; + struct timeval tv; + double end, t; + if (!timeout) { + return TIMEOUT; + } + + end = __gettime() + timeout/1000; + do { + // Specify what I/O operations we care about + if (mode & WAIT_MODE_R) { + FD_ZERO(&rfds); + FD_SET(*sock, &rfds); + } + if (mode & WAIT_MODE_W) { + FD_ZERO(&wfds); + FD_SET(*sock, &wfds); + } + + // Check for timeout + t = end - __gettime(); + if (t < 0.0) { + break; + } + + // Wait + tv.tv_sec = (int)t; + tv.tv_usec = (int)((t - tv.tv_sec) * 1.0e6); + ret = select(*sock+1, &rfds, &wfds, NULL, &tv); + } while (ret == -1 && errno == EINTR); + if (ret == -1) { + return errno; + } + + // Check for timeout + if (ret == 0) { + return TIMEOUT; + } + + // Verify that we can actually read from the remote host + if (mode & WAIT_MODE_C && FD_ISSET(*sock, &rfds) && + recv(*sock, (char*) &rfds, 0, 0) != 0) { + return errno; + } + + return SUCCESS; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol) { + *sock = socket(domain, type, protocol); + if (*sock > 0) { + return SUCCESS; + } else { + return errno; + } +} + +T_ERRCODE socket_destroy(p_socket sock) { + // TODO Figure out if I should be free-ing this + if (*sock > 0) { + socket_setblocking(sock); + close(*sock); + *sock = -1; + } + return SUCCESS; +} + +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len) { + int ret = SUCCESS; + socket_setblocking(sock); + if (bind(*sock, addr, addr_len)) { + ret = errno; + } + socket_setnonblocking(sock); + return ret; +} + +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len) { + struct sockaddr_in sa; + socklen_t addrlen; + memset(&sa, 0, sizeof(sa)); + int rc = getsockname(*sock, (struct sockaddr*)&sa, &addrlen); + if (!rc) { + char *addr = inet_ntoa(sa.sin_addr); + *port = ntohs(sa.sin_port); + if (strlen(addr) < len) { + len = strlen(addr); + } + memcpy(buf, addr, len); + return SUCCESS; + } + return rc; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +T_ERRCODE socket_accept(p_socket sock, p_socket client, + p_sa addr, socklen_t *addrlen, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + do { + *client = accept(*sock, addr, addrlen); + if (*client > 0) { + return SUCCESS; + } + err = errno; + } while (err != EINTR); + if (err == EAGAIN || err == ECONNABORTED) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + return err; +} + +T_ERRCODE socket_listen(p_socket sock, int backlog) { + int ret = SUCCESS; + socket_setblocking(sock); + if (listen(*sock, backlog)) { + ret = errno; + } + socket_setnonblocking(sock); + return ret; +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + + do { + if (connect(*sock, addr, addr_len) == 0) { + return SUCCESS; + } + } while ((err = errno) == EINTR); + if (err != EINPROGRESS && err != EAGAIN) { + return err; + } + return socket_wait(sock, WAIT_MODE_C, timeout); +} + +T_ERRCODE socket_send( + p_socket sock, const char *data, size_t len, int timeout) { + int err, put = 0; + if (*sock < 0) { + return CLOSED; + } + do { + put = send(*sock, data, len, 0); + if (put > 0) { + return SUCCESS; + } + err = errno; + } while (err != EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_W, timeout); + } + return err; +} + +T_ERRCODE socket_recv( + p_socket sock, char *data, size_t len, int timeout, int *received) { + int err, got = 0; + if (*sock < 0) { + return CLOSED; + } + + int flags = fcntl(*sock, F_GETFL, 0); + do { + got = recv(*sock, data, len, 0); + if (got > 0) { + *received = got; + return SUCCESS; + } + err = errno; + + // Connection has been closed by peer + if (got == 0) { + return CLOSED; + } + } while (err != EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + return err; +} + +//////////////////////////////////////////////////////////////////////////////// +// Util + +void socket_setnonblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags |= O_NONBLOCK; + fcntl(*sock, F_SETFL, flags); +} + +void socket_setblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags &= (~(O_NONBLOCK)); + fcntl(*sock, F_SETFL, flags); +} + +//////////////////////////////////////////////////////////////////////////////// +// TCP + +#define ERRORSTR_RETURN(err) \ + if (err == SUCCESS) { \ + return NULL; \ + } else if (err == TIMEOUT) { \ + return TIMEOUT_MSG; \ + } else if (err == CLOSED) { \ + return CLOSED_MSG; \ + } \ + return strerror(err) + +const char * tcp_create(p_socket sock) { + int err = socket_create(sock, AF_INET, SOCK_STREAM, 0); + ERRORSTR_RETURN(err); +} + +const char * tcp_destroy(p_socket sock) { + int err = socket_destroy(sock); + ERRORSTR_RETURN(err); +} + +const char * tcp_bind(p_socket sock, const char *host, unsigned short port) { + int err; + struct hostent *h; + struct sockaddr_in local; + memset(&local, 0, sizeof(local)); + local.sin_family = AF_INET; + local.sin_addr.s_addr = htonl(INADDR_ANY); + local.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &local.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&local.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_bind(sock, (p_sa) &local, sizeof(local)); + ERRORSTR_RETURN(err); +} + +const char * tcp_listen(p_socket sock, int backlog) { + int err = socket_listen(sock, backlog); + ERRORSTR_RETURN(err); +} + +const char * tcp_accept(p_socket sock, p_socket client, int timeout) { + int err = socket_accept(sock, client, NULL, NULL, timeout); + ERRORSTR_RETURN(err); +} + +const char * tcp_connect(p_socket sock, + const char *host, + unsigned short port, + int timeout) { + int err; + struct hostent *h; + struct sockaddr_in remote; + memset(&remote, 0, sizeof(remote)); + remote.sin_family = AF_INET; + remote.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &remote.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&remote.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_connect(sock, (p_sa) &remote, sizeof(remote), timeout); + ERRORSTR_RETURN(err); +} + +#define WRITE_STEP 8192 +const char * tcp_send( + p_socket sock, const char * data, size_t w_len, int timeout) { + int err; + size_t put = 0, step; + if (!w_len) { + return NULL; + } + + do { + step = (WRITE_STEP < w_len - put ? WRITE_STEP : w_len - put); + err = socket_send(sock, data + put, step, timeout); + put += step; + } while (err == SUCCESS && put < w_len); + ERRORSTR_RETURN(err); +} + +const char * tcp_raw_receive( + p_socket sock, char * data, size_t r_len, int timeout, int *received) { + int err = socket_recv(sock, data, r_len, timeout, received); + ERRORSTR_RETURN(err); +} |