summaryrefslogtreecommitdiff
path: root/vendor/src/github.com/godbus/dbus/transport_unix.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/src/github.com/godbus/dbus/transport_unix.go')
-rw-r--r--vendor/src/github.com/godbus/dbus/transport_unix.go190
1 files changed, 190 insertions, 0 deletions
diff --git a/vendor/src/github.com/godbus/dbus/transport_unix.go b/vendor/src/github.com/godbus/dbus/transport_unix.go
new file mode 100644
index 0000000000..d16229be40
--- /dev/null
+++ b/vendor/src/github.com/godbus/dbus/transport_unix.go
@@ -0,0 +1,190 @@
+package dbus
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "syscall"
+)
+
+type oobReader struct {
+ conn *net.UnixConn
+ oob []byte
+ buf [4096]byte
+}
+
+func (o *oobReader) Read(b []byte) (n int, err error) {
+ n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
+ if err != nil {
+ return n, err
+ }
+ if flags&syscall.MSG_CTRUNC != 0 {
+ return n, errors.New("dbus: control data truncated (too many fds received)")
+ }
+ o.oob = append(o.oob, o.buf[:oobn]...)
+ return n, nil
+}
+
+type unixTransport struct {
+ *net.UnixConn
+ hasUnixFDs bool
+}
+
+func newUnixTransport(keys string) (transport, error) {
+ var err error
+
+ t := new(unixTransport)
+ abstract := getKey(keys, "abstract")
+ path := getKey(keys, "path")
+ switch {
+ case abstract == "" && path == "":
+ return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
+ case abstract != "" && path == "":
+ t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
+ if err != nil {
+ return nil, err
+ }
+ return t, nil
+ case abstract == "" && path != "":
+ t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
+ if err != nil {
+ return nil, err
+ }
+ return t, nil
+ default:
+ return nil, errors.New("dbus: invalid address (both path and abstract set)")
+ }
+}
+
+func (t *unixTransport) EnableUnixFDs() {
+ t.hasUnixFDs = true
+}
+
+func (t *unixTransport) ReadMessage() (*Message, error) {
+ var (
+ blen, hlen uint32
+ csheader [16]byte
+ headers []header
+ order binary.ByteOrder
+ unixfds uint32
+ )
+ // To be sure that all bytes of out-of-band data are read, we use a special
+ // reader that uses ReadUnix on the underlying connection instead of Read
+ // and gathers the out-of-band data in a buffer.
+ rd := &oobReader{conn: t.UnixConn}
+ // read the first 16 bytes (the part of the header that has a constant size),
+ // from which we can figure out the length of the rest of the message
+ if _, err := io.ReadFull(rd, csheader[:]); err != nil {
+ return nil, err
+ }
+ switch csheader[0] {
+ case 'l':
+ order = binary.LittleEndian
+ case 'B':
+ order = binary.BigEndian
+ default:
+ return nil, InvalidMessageError("invalid byte order")
+ }
+ // csheader[4:8] -> length of message body, csheader[12:16] -> length of
+ // header fields (without alignment)
+ binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
+ binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
+ if hlen%8 != 0 {
+ hlen += 8 - (hlen % 8)
+ }
+
+ // decode headers and look for unix fds
+ headerdata := make([]byte, hlen+4)
+ copy(headerdata, csheader[12:])
+ if _, err := io.ReadFull(t, headerdata[4:]); err != nil {
+ return nil, err
+ }
+ dec := newDecoder(bytes.NewBuffer(headerdata), order)
+ dec.pos = 12
+ vs, err := dec.Decode(Signature{"a(yv)"})
+ if err != nil {
+ return nil, err
+ }
+ Store(vs, &headers)
+ for _, v := range headers {
+ if v.Field == byte(FieldUnixFDs) {
+ unixfds, _ = v.Variant.value.(uint32)
+ }
+ }
+ all := make([]byte, 16+hlen+blen)
+ copy(all, csheader[:])
+ copy(all[16:], headerdata[4:])
+ if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil {
+ return nil, err
+ }
+ if unixfds != 0 {
+ if !t.hasUnixFDs {
+ return nil, errors.New("dbus: got unix fds on unsupported transport")
+ }
+ // read the fds from the OOB data
+ scms, err := syscall.ParseSocketControlMessage(rd.oob)
+ if err != nil {
+ return nil, err
+ }
+ if len(scms) != 1 {
+ return nil, errors.New("dbus: received more than one socket control message")
+ }
+ fds, err := syscall.ParseUnixRights(&scms[0])
+ if err != nil {
+ return nil, err
+ }
+ msg, err := DecodeMessage(bytes.NewBuffer(all))
+ if err != nil {
+ return nil, err
+ }
+ // substitute the values in the message body (which are indices for the
+ // array receiver via OOB) with the actual values
+ for i, v := range msg.Body {
+ if j, ok := v.(UnixFDIndex); ok {
+ if uint32(j) >= unixfds {
+ return nil, InvalidMessageError("invalid index for unix fd")
+ }
+ msg.Body[i] = UnixFD(fds[j])
+ }
+ }
+ return msg, nil
+ }
+ return DecodeMessage(bytes.NewBuffer(all))
+}
+
+func (t *unixTransport) SendMessage(msg *Message) error {
+ fds := make([]int, 0)
+ for i, v := range msg.Body {
+ if fd, ok := v.(UnixFD); ok {
+ msg.Body[i] = UnixFDIndex(len(fds))
+ fds = append(fds, int(fd))
+ }
+ }
+ if len(fds) != 0 {
+ if !t.hasUnixFDs {
+ return errors.New("dbus: unix fd passing not enabled")
+ }
+ msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
+ oob := syscall.UnixRights(fds...)
+ buf := new(bytes.Buffer)
+ msg.EncodeTo(buf, binary.LittleEndian)
+ n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
+ if err != nil {
+ return err
+ }
+ if n != buf.Len() || oobn != len(oob) {
+ return io.ErrShortWrite
+ }
+ } else {
+ if err := msg.EncodeTo(t, binary.LittleEndian); err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
+func (t *unixTransport) SupportsUnixFDs() bool {
+ return true
+}