//go:build linux || freebsd // +build linux freebsd package testutils import ( "fmt" "runtime" "strconv" "testing" "github.com/docker/docker/libnetwork/ns" "github.com/pkg/errors" "github.com/vishvananda/netns" "golang.org/x/sys/unix" ) // OSContext is a handle to a test OS context. type OSContext struct { origNS, newNS netns.NsHandle tid int caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages. } // SetupTestOSContextEx joins the current goroutine to a new network namespace. // // Compared to [SetupTestOSContext], this function allows goroutines to be // spawned which are associated with the same OS context via the returned // OSContext value. // // Example usage: // // c := SetupTestOSContext(t) // defer c.Cleanup(t) func SetupTestOSContextEx(t *testing.T) *OSContext { runtime.LockOSThread() origNS, err := netns.Get() if err != nil { runtime.UnlockOSThread() t.Fatalf("Failed to open initial netns: %v", err) } c := OSContext{ tid: unix.Gettid(), origNS: origNS, } c.newNS, err = netns.New() if err != nil { // netns.New() is not atomic: it could have encountered an error // after unsharing the current thread's network namespace. c.restore(t) t.Fatalf("Failed to enter netns: %v", err) } // Since we are switching to a new test namespace make // sure to re-initialize initNs context ns.Init() nl := ns.NlHandle() lo, err := nl.LinkByName("lo") if err != nil { c.restore(t) t.Fatalf("Failed to get handle to loopback interface 'lo' in new netns: %v", err) } if err := nl.LinkSetUp(lo); err != nil { c.restore(t) t.Fatalf("Failed to enable loopback interface in new netns: %v", err) } _, file, line, ok := runtime.Caller(0) if ok { c.caller = file + ":" + strconv.Itoa(line) } return &c } // Cleanup tears down the OS context. It must be called from the same goroutine // as the [SetupTestOSContextEx] call which returned c. // // Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees // about which goroutine the cleanup functions are invoked on. func (c *OSContext) Cleanup(t *testing.T) { t.Helper() if unix.Gettid() != c.tid { t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller) } if err := c.newNS.Close(); err != nil { t.Logf("Warning: netns closing failed (%v)", err) } c.restore(t) ns.Init() } func (c *OSContext) restore(t *testing.T) { t.Helper() if err := netns.Set(c.origNS); err != nil { t.Logf("Warning: failed to restore thread netns (%v)", err) } else { runtime.UnlockOSThread() } if err := c.origNS.Close(); err != nil { t.Logf("Warning: netns closing failed (%v)", err) } } // Set sets the OS context of the calling goroutine to c and returns a teardown // function to restore the calling goroutine's OS context and release resources. // The teardown function accepts an optional Logger argument. // // This is a lower-level interface which is less ergonomic than c.Go() but more // composable with other goroutine-spawning utilities such as [sync.WaitGroup] // or [golang.org/x/sync/errgroup.Group]. // // Example usage: // // func TestFoo(t *testing.T) { // osctx := testutils.SetupTestOSContextEx(t) // defer osctx.Cleanup(t) // var eg errgroup.Group // eg.Go(func() error { // teardown, err := osctx.Set() // if err != nil { // return err // } // defer teardown(t) // // ... // }) // if err := eg.Wait(); err != nil { // t.Fatalf("%+v", err) // } // } func (c *OSContext) Set() (func(Logger), error) { runtime.LockOSThread() orig, err := netns.Get() if err != nil { runtime.UnlockOSThread() return nil, errors.Wrap(err, "failed to open initial netns for goroutine") } if err := errors.WithStack(netns.Set(c.newNS)); err != nil { runtime.UnlockOSThread() return nil, errors.Wrap(err, "failed to set goroutine network namespace") } tid := unix.Gettid() _, file, line, callerOK := runtime.Caller(0) return func(log Logger) { if unix.Gettid() != tid { msg := "teardown function must be called from the same goroutine as c.Set()" if callerOK { msg += fmt.Sprintf(" (%s:%d)", file, line) } panic(msg) } if err := netns.Set(orig); err != nil && log != nil { log.Logf("Warning: failed to restore goroutine thread netns (%v)", err) } else { runtime.UnlockOSThread() } if err := orig.Close(); err != nil && log != nil { log.Logf("Warning: netns closing failed (%v)", err) } }, nil }