summaryrefslogtreecommitdiff
path: root/src/os
diff options
context:
space:
mode:
authorAnanth Bhaskararaman <antsub@gmail.com>2023-03-22 05:37:43 +0000
committerGopher Robot <gobot@golang.org>2023-03-23 17:32:05 +0000
commit1596d71255f3c19a27600e751592079dae46bf40 (patch)
tree2318c54a9adeb3ae7105d422b0422b2e113011c6 /src/os
parent0aa14fca8c639c9ceba264dbf0d82bd53306aeaa (diff)
downloadgo-git-1596d71255f3c19a27600e751592079dae46bf40.tar.gz
os/user: lookup Linux users and groups via systemd userdb
Fetch usernames and groups via systemd userdb if available. Otherwise fall back to parsing /etc/passwd, etc. Fixes #38810 Co-authored-by: Michael Stapelberg <stapelberg@google.com> Change-Id: Iff6ffc54feec6b6cec241b89e362c2285c8c0454 GitHub-Last-Rev: 1a627cc9a18063f5d274bb96113947cd4d952e5a GitHub-Pull-Request: golang/go#57458 Reviewed-on: https://go-review.googlesource.com/c/go/+/459455 TryBot-Result: Gopher Robot <gobot@golang.org> Run-TryBot: Ian Lance Taylor <iant@google.com> Reviewed-by: Ian Lance Taylor <iant@google.com> Reviewed-by: Heschi Kreinick <heschi@google.com> Auto-Submit: Ian Lance Taylor <iant@google.com>
Diffstat (limited to 'src/os')
-rw-r--r--src/os/user/listgroups_unix.go9
-rw-r--r--src/os/user/lookup_unix.go30
-rw-r--r--src/os/user/user.go4
-rw-r--r--src/os/user/userdbclient.go22
-rw-r--r--src/os/user/userdbclient_linux.go772
-rw-r--r--src/os/user/userdbclient_linux_test.go504
-rw-r--r--src/os/user/userdbclient_stub.go29
7 files changed, 1370 insertions, 0 deletions
diff --git a/src/os/user/listgroups_unix.go b/src/os/user/listgroups_unix.go
index ef366fa280..b620ad3652 100644
--- a/src/os/user/listgroups_unix.go
+++ b/src/os/user/listgroups_unix.go
@@ -9,11 +9,13 @@ package user
import (
"bufio"
"bytes"
+ "context"
"errors"
"fmt"
"io"
"os"
"strconv"
+ "time"
)
func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
@@ -99,6 +101,13 @@ func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
}
func listGroups(u *User) ([]string, error) {
+ if defaultUserdbClient.isUsable() {
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+ defer cancel()
+ if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok {
+ return ids, err
+ }
+ }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
diff --git a/src/os/user/lookup_unix.go b/src/os/user/lookup_unix.go
index 608d9b2140..0ee2ad35ef 100644
--- a/src/os/user/lookup_unix.go
+++ b/src/os/user/lookup_unix.go
@@ -9,11 +9,13 @@ package user
import (
"bufio"
"bytes"
+ "context"
"errors"
"io"
"os"
"strconv"
"strings"
+ "time"
)
// lineFunc returns a value, an error, or (nil, nil) to skip the row.
@@ -198,6 +200,13 @@ func findUsername(name string, r io.Reader) (*User, error) {
}
func lookupGroup(groupname string) (*Group, error) {
+ if defaultUserdbClient.isUsable() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok {
+ return g, err
+ }
+ }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
@@ -207,6 +216,13 @@ func lookupGroup(groupname string) (*Group, error) {
}
func lookupGroupId(id string) (*Group, error) {
+ if defaultUserdbClient.isUsable() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok {
+ return g, err
+ }
+ }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
@@ -216,6 +232,13 @@ func lookupGroupId(id string) (*Group, error) {
}
func lookupUser(username string) (*User, error) {
+ if defaultUserdbClient.isUsable() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok {
+ return u, err
+ }
+ }
f, err := os.Open(userFile)
if err != nil {
return nil, err
@@ -225,6 +248,13 @@ func lookupUser(username string) (*User, error) {
}
func lookupUserId(uid string) (*User, error) {
+ if defaultUserdbClient.isUsable() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok {
+ return u, err
+ }
+ }
f, err := os.Open(userFile)
if err != nil {
return nil, err
diff --git a/src/os/user/user.go b/src/os/user/user.go
index 0307d2ad6a..4cf5b7c515 100644
--- a/src/os/user/user.go
+++ b/src/os/user/user.go
@@ -11,6 +11,10 @@ One is written in pure Go and parses /etc/passwd and /etc/group. The other
is cgo-based and relies on the standard C library (libc) routines such as
getpwuid_r, getgrnam_r, and getgrouplist.
+For Linux, the pure Go implementation queries the systemd-userdb service first.
+If the service is not available, it falls back to parsing /etc/passwd and
+/etc/group.
+
When cgo is available, and the required routines are implemented in libc
for a particular platform, cgo-based (libc-backed) code is used.
This can be overridden by using osusergo build tag, which enforces
diff --git a/src/os/user/userdbclient.go b/src/os/user/userdbclient.go
new file mode 100644
index 0000000000..b0f3895ed4
--- /dev/null
+++ b/src/os/user/userdbclient.go
@@ -0,0 +1,22 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package user
+
+// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by
+// systemd-userdbd.service(8) on Linux for obtaining full user/group details
+// even when cgo is not available.
+// VARLINK protocol: https://varlink.org
+// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API
+// dir contains multiple varlink service sockets implementing the userdb interface.
+type userdbClient struct {
+ dir string
+}
+
+// IsUsable checks if the client can be used to make queries.
+func (cl userdbClient) isUsable() bool {
+ return len(cl.dir) != 0
+}
+
+var defaultUserdbClient userdbClient
diff --git a/src/os/user/userdbclient_linux.go b/src/os/user/userdbclient_linux.go
new file mode 100644
index 0000000000..e585b7f3c3
--- /dev/null
+++ b/src/os/user/userdbclient_linux.go
@@ -0,0 +1,772 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+
+package user
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "syscall"
+ "unicode/utf16"
+ "unicode/utf8"
+)
+
+const (
+ // Well known multiplexer service.
+ svcMultiplexer = "io.systemd.Multiplexer"
+
+ userdbNamespace = "io.systemd.UserDatabase"
+
+ // io.systemd.UserDatabase VARLINK interface methods.
+ mGetGroupRecord = userdbNamespace + ".GetGroupRecord"
+ mGetUserRecord = userdbNamespace + ".GetUserRecord"
+ mGetMemberships = userdbNamespace + ".GetMemberships"
+
+ // io.systemd.UserDatabase VARLINK interface errors.
+ errNoRecordFound = userdbNamespace + ".NoRecordFound"
+ errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable"
+)
+
+func init() {
+ defaultUserdbClient.dir = "/run/systemd/userdb"
+}
+
+// userdbCall represents a VARLINK service call sent to systemd-userdb.
+// method is the VARLINK method to call.
+// parameters are the VARLINK parameters to pass.
+// more indicates if more responses are expected.
+// fastest indicates if only the fastest response should be returned.
+type userdbCall struct {
+ method string
+ parameters callParameters
+ more bool
+ fastest bool
+}
+
+func (u userdbCall) marshalJSON(service string) ([]byte, error) {
+ params, err := u.parameters.marshalJSON(service)
+ if err != nil {
+ return nil, err
+ }
+ var data bytes.Buffer
+ data.WriteString(`{"method":"`)
+ data.WriteString(u.method)
+ data.WriteString(`","parameters":`)
+ data.Write(params)
+ if u.more {
+ data.WriteString(`,"more":true`)
+ }
+ data.WriteString(`}`)
+ return data.Bytes(), nil
+}
+
+type callParameters struct {
+ uid *int64
+ userName string
+ gid *int64
+ groupName string
+}
+
+func (c callParameters) marshalJSON(service string) ([]byte, error) {
+ var data bytes.Buffer
+ data.WriteString(`{"service":"`)
+ data.WriteString(service)
+ data.WriteString(`"`)
+ if c.uid != nil {
+ data.WriteString(`,"uid":`)
+ data.WriteString(strconv.FormatInt(*c.uid, 10))
+ }
+ if c.userName != "" {
+ data.WriteString(`,"userName":"`)
+ data.WriteString(c.userName)
+ data.WriteString(`"`)
+ }
+ if c.gid != nil {
+ data.WriteString(`,"gid":`)
+ data.WriteString(strconv.FormatInt(*c.gid, 10))
+ }
+ if c.groupName != "" {
+ data.WriteString(`,"groupName":"`)
+ data.WriteString(c.groupName)
+ data.WriteString(`"`)
+ }
+ data.WriteString(`}`)
+ return data.Bytes(), nil
+}
+
+type userdbReply struct {
+ continues bool
+ errorStr string
+}
+
+func (u *userdbReply) unmarshalJSON(data []byte) error {
+ var (
+ kContinues = []byte(`"continues"`)
+ kError = []byte(`"error"`)
+ )
+ if i := bytes.Index(data, kContinues); i != -1 {
+ continues, err := parseJSONBoolean(data[i+len(kContinues):])
+ if err != nil {
+ return err
+ }
+ u.continues = continues
+ }
+ if i := bytes.Index(data, kError); i != -1 {
+ errStr, err := parseJSONString(data[i+len(kError):])
+ if err != nil {
+ return err
+ }
+ u.errorStr = errStr
+ }
+ return nil
+}
+
+// response is the parsed reply from a method call to systemd-userdb.
+// data is one or more VARLINK response parameters separated by 0.
+// handled indicates if the call was handled by systemd-userdb.
+// err is any error encountered.
+type response struct {
+ data []byte
+ handled bool
+ err error
+}
+
+// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request.
+// Multiple replies can be read by setting more to true in the request.
+// Reply parameters are accumulated separated by 0, if there are many.
+// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped.
+// Other UserDatabase errors are returned as is.
+// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable
+// error is seen in a response, the query is considered unhandled.
+func querySocket(ctx context.Context, sock string, request []byte) response {
+ sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ return response{err: err}
+ }
+ defer syscall.Close(sockFd)
+ if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return response{err: err}
+ }
+ return response{handled: true, err: err}
+ }
+
+ // Null terminate request.
+ if request[len(request)-1] != 0 {
+ request = append(request, 0)
+ }
+
+ // Write request to socket.
+ written := 0
+ for written < len(request) {
+ if ctx.Err() != nil {
+ return response{handled: true, err: ctx.Err()}
+ }
+ if n, err := syscall.Write(sockFd, request[written:]); err != nil {
+ return response{handled: true, err: err}
+ } else {
+ written += n
+ }
+ }
+
+ // Read response.
+ var resp bytes.Buffer
+ for {
+ if ctx.Err() != nil {
+ return response{handled: true, err: ctx.Err()}
+ }
+ buf := make([]byte, 4096)
+ if n, err := syscall.Read(sockFd, buf); err != nil {
+ return response{handled: true, err: err}
+ } else if n > 0 {
+ resp.Write(buf[:n])
+ if buf[n-1] == 0 {
+ break
+ }
+ } else {
+ // EOF
+ break
+ }
+ }
+
+ if resp.Len() == 0 {
+ return response{handled: true}
+ }
+
+ buf := resp.Bytes()
+ // Remove trailing 0.
+ buf = buf[:len(buf)-1]
+ // Split into VARLINK messages.
+ msgs := bytes.Split(buf, []byte{0})
+
+ // Parse VARLINK messages.
+ for _, m := range msgs {
+ var resp userdbReply
+ if err := resp.unmarshalJSON(m); err != nil {
+ return response{handled: true, err: err}
+ }
+ // Handle VARLINK message errors.
+ switch e := resp.errorStr; e {
+ case "":
+ case errNoRecordFound: // Ignore not found error.
+ continue
+ case errServiceNotAvailable:
+ return response{}
+ default:
+ return response{handled: true, err: errors.New(e)}
+ }
+ if !resp.continues {
+ break
+ }
+ }
+ return response{data: buf, handled: true, err: ctx.Err()}
+}
+
+// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once.
+// ss is a slice of userdb services to call. Each service must have a socket in cl.dir.
+// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read.
+// Otherwise all replies are aggregated. um is called with aggregated reply parameters.
+// queryMany returns the first error encountered. The first result is false if no userdb
+// socket is available or if all requests time out.
+func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) {
+ responseCh := make(chan response, len(ss))
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ // Query all services in parallel.
+ var workers sync.WaitGroup
+ for _, svc := range ss {
+ data, err := c.marshalJSON(svc)
+ if err != nil {
+ return true, err
+ }
+ // Spawn worker to query service.
+ workers.Add(1)
+ go func(sock string, data []byte) {
+ defer workers.Done()
+ responseCh <- querySocket(ctx, sock, data)
+ }(cl.dir+"/"+svc, data)
+ }
+
+ go func() {
+ // Clean up workers.
+ workers.Wait()
+ close(responseCh)
+ }()
+
+ var result bytes.Buffer
+ var notOk int
+RecvResponses:
+ for {
+ select {
+ case resp, ok := <-responseCh:
+ if !ok {
+ // Responses channel is closed so stop reading.
+ break RecvResponses
+ }
+ if resp.err != nil {
+ // querySocket only returns unrecoverable errors,
+ // so return the first one received.
+ return true, resp.err
+ }
+ if !resp.handled {
+ notOk++
+ continue
+ }
+
+ first := result.Len() == 0
+ result.Write(resp.data)
+ if first && c.fastest {
+ // Return the fastest response.
+ break RecvResponses
+ }
+ case <-ctx.Done():
+ // If requests time out, userdb is unavailable.
+ return ctx.Err() != context.DeadlineExceeded, nil
+ }
+ }
+ // If all sockets are not ok, userdb is unavailable.
+ if notOk == len(ss) {
+ return false, nil
+ }
+ return true, um.unmarshalJSON(result.Bytes())
+}
+
+// services enumerates userdb service sockets in dir.
+// If ok is false, io.systemd.UserDatabase service does not exist.
+func (cl userdbClient) services() (s []string, ok bool, err error) {
+ var entries []fs.DirEntry
+ if entries, err = os.ReadDir(cl.dir); err != nil {
+ ok = !os.IsNotExist(err)
+ return
+ }
+ ok = true
+ for _, ent := range entries {
+ s = append(s, ent.Name())
+ }
+ return
+}
+
+// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface.
+// If the multiplexer service is available, the call is sent only to it.
+// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir.
+// The fastest reply is read and parsed. All other requests are cancelled.
+// If the service is unavailable, the first result is false.
+// The service is considered unavailable if the requests time-out as well.
+func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) {
+ services := []string{svcMultiplexer}
+ if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil {
+ // No mux service so call all available services.
+ var ok bool
+ if services, ok, err = cl.services(); !ok || err != nil {
+ return ok, err
+ }
+ }
+ call.fastest = true
+ if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil {
+ return ok, err
+ }
+ return true, nil
+}
+
+type jsonUnmarshaler interface {
+ unmarshalJSON([]byte) error
+}
+
+func isSpace(c byte) bool {
+ return c == ' ' || c == '\t' || c == '\r' || c == '\n'
+}
+
+// findElementStart returns a slice of r that starts at the next JSON element.
+// It skips over valid JSON space characters and checks for the colon separator.
+func findElementStart(r []byte) ([]byte, error) {
+ var idx int
+ var b byte
+ colon := byte(':')
+ var seenColon bool
+ for idx, b = range r {
+ if isSpace(b) {
+ continue
+ }
+ if !seenColon && b == colon {
+ seenColon = true
+ continue
+ }
+ // Spotted colon and b is not a space, so value starts here.
+ if seenColon {
+ break
+ }
+ return nil, errors.New("expected colon, got invalid character: " + string(b))
+ }
+ if !seenColon {
+ return nil, errors.New("expected colon, got end of input")
+ }
+ return r[idx:], nil
+}
+
+// parseJSONString reads a JSON string from r.
+func parseJSONString(r []byte) (string, error) {
+ r, err := findElementStart(r)
+ if err != nil {
+ return "", err
+ }
+ // Smallest valid string is `""`.
+ if l := len(r); l < 2 {
+ return "", errors.New("unexpected end of input")
+ } else if l == 2 {
+ if bytes.Equal(r, []byte(`""`)) {
+ return "", nil
+ }
+ return "", errors.New("invalid string")
+ }
+
+ if c := r[0]; c != '"' {
+ return "", errors.New(`expected " got ` + string(c))
+ }
+ // Advance over opening quote.
+ r = r[1:]
+
+ var value strings.Builder
+ var inEsc bool
+ var inUEsc bool
+ var strEnds bool
+ reader := bytes.NewReader(r)
+ for {
+ if value.Len() > 4096 {
+ return "", errors.New("string too large")
+ }
+
+ // Parse unicode escape sequences.
+ if inUEsc {
+ maybeRune := make([]byte, 4)
+ n, err := reader.Read(maybeRune)
+ if err != nil || n != 4 {
+ return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
+ }
+ prn, err := strconv.ParseUint(string(maybeRune), 16, 32)
+ if err != nil {
+ return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
+ }
+ rn := rune(prn)
+ if !utf16.IsSurrogate(rn) {
+ value.WriteRune(rn)
+ inUEsc = false
+ continue
+ }
+ // rn maybe a high surrogate; read the low surrogate.
+ maybeRune = make([]byte, 6)
+ n, err = reader.Read(maybeRune)
+ if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' {
+ // Not a valid UTF-16 surrogate pair.
+ if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil {
+ return "", err
+ }
+ // Invalid low surrogate; write the replacement character.
+ value.WriteRune(utf8.RuneError)
+ } else {
+ rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32)
+ if err != nil {
+ return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune))
+ }
+ // Check if rn and rn1 are valid UTF-16 surrogate pairs.
+ if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError {
+ n = utf8.EncodeRune(maybeRune, dec)
+ // Write the decoded rune.
+ value.Write(maybeRune[:n])
+ }
+ }
+ inUEsc = false
+ continue
+ }
+
+ if inEsc {
+ b, err := reader.ReadByte()
+ if err != nil {
+ return "", err
+ }
+ switch b {
+ case 'b':
+ value.WriteByte('\b')
+ case 'f':
+ value.WriteByte('\f')
+ case 'n':
+ value.WriteByte('\n')
+ case 'r':
+ value.WriteByte('\r')
+ case 't':
+ value.WriteByte('\t')
+ case 'u':
+ inUEsc = true
+ case '/':
+ value.WriteByte('/')
+ case '\\':
+ value.WriteByte('\\')
+ case '"':
+ value.WriteByte('"')
+ default:
+ return "", errors.New("unexpected character in escape sequence " + string(b))
+ }
+ inEsc = false
+ continue
+ } else {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return "", err
+ }
+ if rn == '\\' {
+ inEsc = true
+ continue
+ }
+ if rn == '"' {
+ // String ends on un-escaped quote.
+ strEnds = true
+ break
+ }
+ value.WriteRune(rn)
+ }
+ }
+ if !strEnds {
+ return "", errors.New("unexpected end of input")
+ }
+ return value.String(), nil
+}
+
+// parseJSONInt64 reads a 64 bit integer from r.
+func parseJSONInt64(r []byte) (int64, error) {
+ r, err := findElementStart(r)
+ if err != nil {
+ return 0, err
+ }
+ var num strings.Builder
+ for _, b := range r {
+ // int64 max is 19 digits long.
+ if num.Len() == 20 {
+ return 0, errors.New("number too large")
+ }
+ if strings.ContainsRune("0123456789", rune(b)) {
+ num.WriteByte(b)
+ } else {
+ break
+ }
+ }
+ n, err := strconv.ParseInt(num.String(), 10, 64)
+ return int64(n), err
+}
+
+// parseJSONBoolean reads a boolean from r.
+func parseJSONBoolean(r []byte) (bool, error) {
+ r, err := findElementStart(r)
+ if err != nil {
+ return false, err
+ }
+ if bytes.HasPrefix(r, []byte("true")) {
+ return true, nil
+ }
+ if bytes.HasPrefix(r, []byte("false")) {
+ return false, nil
+ }
+ return false, errors.New("unable to parse boolean value")
+}
+
+type groupRecord struct {
+ groupName string
+ gid int64
+}
+
+func (g *groupRecord) unmarshalJSON(data []byte) error {
+ var (
+ kGroupName = []byte(`"groupName"`)
+ kGid = []byte(`"gid"`)
+ )
+ if i := bytes.Index(data, kGroupName); i != -1 {
+ groupname, err := parseJSONString(data[i+len(kGroupName):])
+ if err != nil {
+ return err
+ }
+ g.groupName = groupname
+ }
+ if i := bytes.Index(data, kGid); i != -1 {
+ gid, err := parseJSONInt64(data[i+len(kGid):])
+ if err != nil {
+ return err
+ }
+ g.gid = gid
+ }
+ return nil
+}
+
+// queryGroupDb queries the userdb interface for a gid, groupname, or both.
+func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) {
+ group := groupRecord{}
+ request := userdbCall{
+ method: mGetGroupRecord,
+ parameters: callParameters{gid: gid, groupName: groupname},
+ }
+ if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
+ return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
+ }
+ return &Group{
+ Name: group.groupName,
+ Gid: strconv.FormatInt(group.gid, 10),
+ }, true, nil
+}
+
+type userRecord struct {
+ userName string
+ realName string
+ uid int64
+ gid int64
+ homeDirectory string
+}
+
+func (u *userRecord) unmarshalJSON(data []byte) error {
+ var (
+ kUserName = []byte(`"userName"`)
+ kRealName = []byte(`"realName"`)
+ kUid = []byte(`"uid"`)
+ kGid = []byte(`"gid"`)
+ kHomeDirectory = []byte(`"homeDirectory"`)
+ )
+ if i := bytes.Index(data, kUserName); i != -1 {
+ username, err := parseJSONString(data[i+len(kUserName):])
+ if err != nil {
+ return err
+ }
+ u.userName = username
+ }
+ if i := bytes.Index(data, kRealName); i != -1 {
+ realname, err := parseJSONString(data[i+len(kRealName):])
+ if err != nil {
+ return err
+ }
+ u.realName = realname
+ }
+ if i := bytes.Index(data, kUid); i != -1 {
+ uid, err := parseJSONInt64(data[i+len(kUid):])
+ if err != nil {
+ return err
+ }
+ u.uid = uid
+ }
+ if i := bytes.Index(data, kGid); i != -1 {
+ gid, err := parseJSONInt64(data[i+len(kGid):])
+ if err != nil {
+ return err
+ }
+ u.gid = gid
+ }
+ if i := bytes.Index(data, kHomeDirectory); i != -1 {
+ homedir, err := parseJSONString(data[i+len(kHomeDirectory):])
+ if err != nil {
+ return err
+ }
+ u.homeDirectory = homedir
+ }
+ return nil
+}
+
+// queryUserDb queries the userdb interface for a uid, username, or both.
+func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) {
+ user := userRecord{}
+ request := userdbCall{
+ method: mGetUserRecord,
+ parameters: callParameters{
+ uid: uid,
+ userName: username,
+ },
+ }
+ if ok, err := cl.query(ctx, &request, &user); !ok || err != nil {
+ return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err)
+ }
+ return &User{
+ Uid: strconv.FormatInt(user.uid, 10),
+ Gid: strconv.FormatInt(user.gid, 10),
+ Username: user.userName,
+ Name: user.realName,
+ HomeDir: user.homeDirectory,
+ }, true, nil
+}
+
+func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) {
+ return cl.queryGroupDb(ctx, nil, groupname)
+}
+
+func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) {
+ gid, err := strconv.ParseInt(id, 10, 64)
+ if err != nil {
+ return nil, true, err
+ }
+ return cl.queryGroupDb(ctx, &gid, "")
+}
+
+func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) {
+ return cl.queryUserDb(ctx, nil, username)
+}
+
+func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) {
+ uid, err := strconv.ParseInt(id, 10, 64)
+ if err != nil {
+ return nil, true, err
+ }
+ return cl.queryUserDb(ctx, &uid, "")
+}
+
+type memberships struct {
+ // Keys are groupNames and values are sets of userNames.
+ groupUsers map[string]map[string]struct{}
+}
+
+// unmarshalJSON expects many (userName, groupName) records separated by a null byte.
+// This is used to build a membership map.
+func (m *memberships) unmarshalJSON(data []byte) error {
+ if m.groupUsers == nil {
+ m.groupUsers = make(map[string]map[string]struct{})
+ }
+ var (
+ kUserName = []byte(`"userName"`)
+ kGroupName = []byte(`"groupName"`)
+ )
+ // Split records by null terminator.
+ records := bytes.Split(data, []byte{byte(0)})
+ for _, rec := range records {
+ if len(rec) == 0 {
+ continue
+ }
+ var groupName string
+ var userName string
+ var err error
+ if i := bytes.Index(rec, kGroupName); i != -1 {
+ if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil {
+ return err
+ }
+ }
+ if i := bytes.Index(rec, kUserName); i != -1 {
+ if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil {
+ return err
+ }
+ }
+ // Associate userName with groupName.
+ if groupName != "" && userName != "" {
+ if _, ok := m.groupUsers[groupName]; ok {
+ m.groupUsers[groupName][userName] = struct{}{}
+ } else {
+ m.groupUsers[groupName] = map[string]struct{}{userName: {}}
+ }
+ }
+ }
+ return nil
+}
+
+func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) {
+ services, ok, err := cl.services()
+ if !ok || err != nil {
+ return nil, ok, err
+ }
+ // Fetch group memberships for username.
+ var ms memberships
+ request := userdbCall{
+ method: mGetMemberships,
+ parameters: callParameters{userName: username},
+ more: true,
+ }
+ if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil {
+ return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err)
+ }
+ // Fetch user group gid.
+ var group groupRecord
+ request = userdbCall{
+ method: mGetGroupRecord,
+ parameters: callParameters{groupName: username},
+ }
+ if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
+ return nil, ok, err
+ }
+ gids := []string{strconv.FormatInt(group.gid, 10)}
+
+ // Fetch group records for each group.
+ for g := range ms.groupUsers {
+ var group groupRecord
+ request.parameters.groupName = g
+ // Query group for gid.
+ if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
+ return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
+ }
+ gids = append(gids, strconv.FormatInt(group.gid, 10))
+ }
+ return gids, true, nil
+}
diff --git a/src/os/user/userdbclient_linux_test.go b/src/os/user/userdbclient_linux_test.go
new file mode 100644
index 0000000000..1b9a336f72
--- /dev/null
+++ b/src/os/user/userdbclient_linux_test.go
@@ -0,0 +1,504 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+
+package user
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+ "unicode/utf8"
+)
+
+func TestQueryNoUserdb(t *testing.T) {
+ cl := &userdbClient{dir: "/non/existent"}
+ if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok {
+ t.Fatalf("should fail but lookup has been handled or error is nil: %v", err)
+ }
+}
+
+type userdbTestData map[string]udbResponse
+
+type udbResponse struct {
+ data []byte
+ delay time.Duration
+}
+
+func userdbServer(t *testing.T, sockFn string, data userdbTestData) {
+ ready := make(chan struct{})
+ go func() {
+ if err := serveUserdb(ready, sockFn, data); err != nil {
+ t.Error(err)
+ }
+ }()
+ <-ready
+}
+
+func (u userdbTestData) String() string {
+ var s strings.Builder
+ for k, v := range u {
+ s.WriteString("Request:\n")
+ s.WriteString(k)
+ s.WriteString("\nResponse:\n")
+ if v.delay > 0 {
+ s.WriteString("Delay: ")
+ s.WriteString(v.delay.String())
+ s.WriteString("\n")
+ }
+ s.WriteString("Data:\n")
+ s.Write(v.data)
+ s.WriteString("\n")
+ }
+ return s.String()
+}
+
+// serverUserdb is a simple userdb server that replies to VARLINK method calls.
+// A message is sent on the ready channel when the server is ready to accept calls.
+// The server will reply to each request in the data map. If a request is not
+// found in the map, the server will return an error.
+func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error {
+ sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ return err
+ }
+ defer syscall.Close(sockFd)
+ if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil {
+ return err
+ }
+ if err := syscall.Listen(sockFd, 1); err != nil {
+ return err
+ }
+
+ // Send ready signal.
+ ready <- struct{}{}
+
+ var srvGroup sync.WaitGroup
+
+ srvErrs := make(chan error, len(data))
+ for len(data) != 0 {
+ nfd, _, err := syscall.Accept(sockFd)
+ if err != nil {
+ syscall.Close(nfd)
+ return err
+ }
+
+ // Read request.
+ buf := make([]byte, 4096)
+ n, err := syscall.Read(nfd, buf)
+ if err != nil {
+ syscall.Close(nfd)
+ return err
+ }
+ if n == 0 {
+ // Client went away.
+ continue
+ }
+ if buf[n-1] != 0 {
+ syscall.Close(nfd)
+ return errors.New("request not null terminated")
+ }
+ // Remove null terminator.
+ buf = buf[:n-1]
+ got := string(buf)
+
+ // Fetch response for request.
+ response, ok := data[got]
+ if !ok {
+ syscall.Close(nfd)
+ msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String()
+ return errors.New(msg)
+ }
+ delete(data, got)
+
+ srvGroup.Add(1)
+ go func() {
+ defer srvGroup.Done()
+ if err := serveClient(nfd, response); err != nil {
+ srvErrs <- err
+ }
+ }()
+ }
+
+ srvGroup.Wait()
+ // Combine serve errors if any.
+ if len(srvErrs) > 0 {
+ var errs []error
+ for err := range srvErrs {
+ errs = append(errs, err)
+ }
+ return errors.Join(errs...)
+ }
+
+ return nil
+}
+
+func serveClient(fd int, response udbResponse) error {
+ defer syscall.Close(fd)
+ time.Sleep(response.delay)
+ data := response.data
+ if len(data) != 0 && data[len(data)-1] != 0 {
+ data = append(data, 0)
+ }
+ written := 0
+ for written < len(data) {
+ if n, err := syscall.Write(fd, data[written:]); err != nil {
+ return err
+ } else {
+ written += n
+ }
+ }
+ return nil
+}
+
+func TestSlowUserdbLookup(t *testing.T) {
+ tmpdir := t.TempDir()
+ data := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
+ delay: time.Hour,
+ },
+ }
+ userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
+ cl := &userdbClient{dir: tmpdir}
+ // Lookup should timeout.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
+ defer cancel()
+ if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok {
+ t.Fatalf("lookup should not be handled but was")
+ }
+}
+
+func TestFastestUserdbLookup(t *testing.T) {
+ tmpdir := t.TempDir()
+ fastData := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ }
+ slowData := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{
+ delay: 50 * time.Millisecond,
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ }
+ userdbServer(t, tmpdir+"/"+"fast", fastData)
+ userdbServer(t, tmpdir+"/"+"slow", slowData)
+ cl := &userdbClient{dir: tmpdir}
+ group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib")
+ if !ok {
+ t.Fatalf("lookup should be handled but was not")
+ }
+ if err != nil {
+ t.Fatalf("lookup should not fail but did: %v", err)
+ }
+ if group.Gid != "181" {
+ t.Fatalf("lookup should return group 181 but returned %s", group.Gid)
+ }
+}
+
+func TestUserdbLookupGroup(t *testing.T) {
+ tmpdir := t.TempDir()
+ data := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ }
+ userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
+
+ groupname := "stdlibcontrib"
+ want := &Group{
+ Name: "stdlibcontrib",
+ Gid: "181",
+ }
+ cl := &userdbClient{dir: tmpdir}
+ got, ok, err := cl.lookupGroup(context.Background(), groupname)
+ if !ok {
+ t.Fatal("lookup should have been handled")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want)
+ }
+}
+
+func TestUserdbLookupUser(t *testing.T) {
+ tmpdir := t.TempDir()
+ data := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ }
+ userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
+
+ username := "stdlibcontrib"
+ want := &User{
+ Uid: "181",
+ Gid: "181",
+ Username: "stdlibcontrib",
+ Name: "Stdlib Contrib",
+ HomeDir: "/home/stdlibcontrib",
+ }
+ cl := &userdbClient{dir: tmpdir}
+ got, ok, err := cl.lookupUser(context.Background(), username)
+ if !ok {
+ t.Fatal("lookup should have been handled")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want)
+ }
+}
+
+func TestUserdbLookupGroupIds(t *testing.T) {
+ tmpdir := t.TempDir()
+ data := userdbTestData{
+ `{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`,
+ ),
+ },
+ // group records
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{
+ data: []byte(
+ `{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
+ ),
+ },
+ }
+ userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
+
+ username := "stdlibcontrib"
+ want := []string{"181", "182", "183"}
+ cl := &userdbClient{dir: tmpdir}
+ got, ok, err := cl.lookupGroupIds(context.Background(), username)
+ if !ok {
+ t.Fatal("lookup should have been handled")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Result order is not specified so sort it.
+ sort.Strings(got)
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want)
+ }
+}
+
+var findElementStartTestCases = []struct {
+ in []byte
+ want []byte
+ err bool
+}{
+ {in: []byte(`:`), want: []byte(``)},
+ {in: []byte(`: `), want: []byte(``)},
+ {in: []byte(`:"foo"`), want: []byte(`"foo"`)},
+ {in: []byte(` :"foo"`), want: []byte(`"foo"`)},
+ {in: []byte(` 1231 :"foo"`), err: true},
+ {in: []byte(``), err: true},
+ {in: []byte(`"foo"`), err: true},
+ {in: []byte(`foo`), err: true},
+}
+
+func TestFindElementStart(t *testing.T) {
+ for i, tc := range findElementStartTestCases {
+ t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
+ got, err := findElementStart(tc.in)
+ if tc.err && err == nil {
+ t.Errorf("want err for findElementStart(%s), got nil", tc.in)
+ }
+ if !tc.err {
+ if err != nil {
+ t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error())
+ }
+ if !bytes.Contains(tc.in, got) {
+ t.Errorf("%s should contain %s but does not", tc.in, got)
+ }
+ }
+ })
+ }
+}
+
+func FuzzFindElementStart(f *testing.F) {
+ for _, tc := range findElementStartTestCases {
+ if !tc.err {
+ f.Add(tc.in)
+ }
+ }
+ f.Fuzz(func(t *testing.T, b []byte) {
+ if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) {
+ t.Errorf("%s, %v", out, err)
+ }
+ })
+}
+
+var parseJSONStringTestCases = []struct {
+ in []byte
+ want string
+ err bool
+}{
+ {in: []byte(`:""`)},
+ {in: []byte(`:"\n"`), want: "\n"},
+ {in: []byte(`: "\""`), want: "\""},
+ {in: []byte(`:"\t \\"`), want: "\t \\"},
+ {in: []byte(`:"\\\\"`), want: `\\`},
+ {in: []byte(`::`), err: true},
+ {in: []byte(`""`), err: true},
+ {in: []byte(`"`), err: true},
+ {in: []byte(":\"0\xE5"), err: true},
+ {in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"},
+ {in: []byte(`:"\u0061a"`), want: "aa"},
+ {in: []byte(`:"\u0159\u0170"`), want: "řŰ"},
+ {in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"},
+ {in: []byte(`:"\uD800"`), want: "\uFFFD"},
+ {in: []byte(`:"\u000"`), err: true},
+ {in: []byte(`:"\u00MF"`), err: true},
+ {in: []byte(`:"\uD800\uDC0"`), err: true},
+}
+
+func TestParseJSONString(t *testing.T) {
+ for i, tc := range parseJSONStringTestCases {
+ t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
+ got, err := parseJSONString(tc.in)
+ if tc.err && err == nil {
+ t.Errorf("want err for parseJSONString(%s), got nil", tc.in)
+ }
+ if !tc.err {
+ if err != nil {
+ t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error())
+ }
+ if tc.want != got {
+ t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want)
+ }
+ }
+ })
+ }
+}
+
+func FuzzParseJSONString(f *testing.F) {
+ for _, tc := range parseJSONStringTestCases {
+ f.Add(tc.in)
+ }
+ f.Fuzz(func(t *testing.T, b []byte) {
+ if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) {
+ t.Errorf("parseJSONString(%s) = %s, invalid string", b, out)
+ }
+ })
+}
+
+var parseJSONInt64TestCases = []struct {
+ in []byte
+ want int64
+ err bool
+}{
+ {in: []byte(":1235"), want: 1235},
+ {in: []byte(": 123"), want: 123},
+ {in: []byte(":0")},
+ {in: []byte(":5012313123131231"), want: 5012313123131231},
+ {in: []byte("1231"), err: true},
+}
+
+func TestParseJSONInt64(t *testing.T) {
+ for i, tc := range parseJSONInt64TestCases {
+ t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
+ got, err := parseJSONInt64(tc.in)
+ if tc.err && err == nil {
+ t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in)
+ }
+ if !tc.err {
+ if err != nil {
+ t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error())
+ }
+ if tc.want != got {
+ t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want)
+ }
+ }
+ })
+ }
+}
+
+func FuzzParseJSONInt64(f *testing.F) {
+ for _, tc := range parseJSONInt64TestCases {
+ f.Add(tc.in)
+ }
+ f.Fuzz(func(t *testing.T, b []byte) {
+ if out, err := parseJSONInt64(b); err == nil &&
+ !bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) {
+ t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err)
+ }
+ })
+}
+
+var parseJSONBooleanTestCases = []struct {
+ in []byte
+ want bool
+ err bool
+}{
+ {in: []byte(": true "), want: true},
+ {in: []byte(":true "), want: true},
+ {in: []byte(": false "), want: false},
+ {in: []byte(":false "), want: false},
+ {in: []byte("true"), err: true},
+ {in: []byte("false"), err: true},
+ {in: []byte("foo"), err: true},
+}
+
+func TestParseJSONBoolean(t *testing.T) {
+ for i, tc := range parseJSONBooleanTestCases {
+ t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
+ got, err := parseJSONBoolean(tc.in)
+ if tc.err && err == nil {
+ t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in)
+ }
+ if !tc.err {
+ if err != nil {
+ t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error())
+ }
+ if tc.want != got {
+ t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want)
+ }
+ }
+ })
+ }
+}
+
+func FuzzParseJSONBoolean(f *testing.F) {
+ for _, tc := range parseJSONBooleanTestCases {
+ f.Add(tc.in)
+ }
+ f.Fuzz(func(t *testing.T, b []byte) {
+ if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) {
+ t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err)
+ }
+ })
+}
diff --git a/src/os/user/userdbclient_stub.go b/src/os/user/userdbclient_stub.go
new file mode 100644
index 0000000000..d31f065c3a
--- /dev/null
+++ b/src/os/user/userdbclient_stub.go
@@ -0,0 +1,29 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !linux
+
+package user
+
+import "context"
+
+func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) {
+ return nil, false, nil
+}
+
+func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) {
+ return nil, false, nil
+}
+
+func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) {
+ return nil, false, nil
+}
+
+func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) {
+ return nil, false, nil
+}
+
+func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) {
+ return nil, false, nil
+}