Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

notify_socket.go: use sd_notify_barrier mechanism #3291

Merged
merged 2 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 75 additions & 21 deletions notify_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"bytes"
"errors"
"io"
"net"
"os"
"path"
Expand All @@ -11,7 +13,9 @@ import (

"github.com/opencontainers/runc/libcontainer"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"github.com/urfave/cli"
"golang.org/x/sys/unix"
)

type notifySocket struct {
Expand Down Expand Up @@ -141,29 +145,79 @@ func (n *notifySocket) run(pid1 int) error {
return nil
}
case b := <-fileChan:
var out bytes.Buffer
_, err = out.Write(b)
if err != nil {
return err
}
return notifyHost(client, b, pid1)
}
}
}

_, err = out.Write([]byte{'\n'})
if err != nil {
return err
}
// notifyHost tells the host (usually systemd) that the container reported READY.
// Also sends MAINPID and BARRIER.
func notifyHost(client *net.UnixConn, ready []byte, pid1 int) error {
_, err := client.Write(append(ready, '\n'))
if err != nil {
return err
}

_, err = client.Write(out.Bytes())
if err != nil {
return err
}
// now we can inform systemd to use pid1 as the pid to monitor
newPid := "MAINPID=" + strconv.Itoa(pid1)
_, err = client.Write([]byte(newPid + "\n"))
if err != nil {
return err
}

// now we can inform systemd to use pid1 as the pid to monitor
newPid := "MAINPID=" + strconv.Itoa(pid1)
_, err := client.Write([]byte(newPid + "\n"))
if err != nil {
return err
}
return nil
}
// wait for systemd to acknowledge the communication
return sdNotifyBarrier(client)
}

// errUnexpectedRead is reported when actual data was read from the pipe used
// to synchronize with systemd. Usually, that pipe is only closed.
var errUnexpectedRead = errors.New("unexpected read from synchronization pipe")

// sdNotifyBarrier performs synchronization with systemd by means of the sd_notify_barrier protocol.
func sdNotifyBarrier(client *net.UnixConn) error {
// Create a pipe for communicating with systemd daemon.
pipeR, pipeW, err := os.Pipe()
if err != nil {
return err
}

// Get the FD for the unix socket file to be able to do perform syscall.Sendmsg.
clientFd, err := client.File()
if err != nil {
return err
}

// Send the write end of the pipe along with a BARRIER=1 message.
fdRights := unix.UnixRights(int(pipeW.Fd()))
err = unix.Sendmsg(int(clientFd.Fd()), []byte("BARRIER=1"), fdRights, nil, 0)
if err != nil {
return &os.SyscallError{Syscall: "sendmsg", Err: err}
}

// Close our copy of pipeW.
err = pipeW.Close()
if err != nil {
return err
}

// Expect the read end of the pipe to be closed after 30 seconds.
err = pipeR.SetReadDeadline(time.Now().Add(30 * time.Second))
if err != nil {
return nil
}

// Read a single byte expecting EOF.
var buf [1]byte
n, err := pipeR.Read(buf[:])
if n != 0 || err == nil {
return errUnexpectedRead
} else if errors.Is(err, os.ErrDeadlineExceeded) {
// Probably the other end doesn't support the sd_notify_barrier protocol.
logrus.Warn("Timeout after waiting 30s for barrier. Ignored.")
return nil
} else if err == io.EOF { //nolint:errorlint // comparison with io.EOF is legit.
return nil
} else {
return err
}
}
120 changes: 120 additions & 0 deletions notify_socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package main

import (
"bytes"
"io"
"net"
"testing"
"time"

"golang.org/x/sys/unix"
)

// TestNotifyHost tests how runc reports container readiness to the host (usually systemd).
func TestNotifyHost(t *testing.T) {
addr := net.UnixAddr{
Name: t.TempDir() + "/testsocket",
Net: "unixgram",
}

server, err := net.ListenUnixgram("unixgram", &addr)
if err != nil {
t.Fatal(err)
}
defer server.Close()

client, err := net.DialUnix("unixgram", nil, &addr)
if err != nil {
t.Fatal(err)
}
defer client.Close()

// run notifyHost in a separate goroutine
notifyHostChan := make(chan error)
go func() {
notifyHostChan <- notifyHost(client, []byte("READY=42"), 1337)
}()

// mock a host process listening for runc's notifications
expectRead(t, server, "READY=42\n")
expectRead(t, server, "MAINPID=1337\n")
expectBarrier(t, server, notifyHostChan)
}

func expectRead(t *testing.T, r io.Reader, expected string) {
var buf [1024]byte
n, err := r.Read(buf[:])
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf[:n], []byte(expected)) {
t.Fatalf("Expected to read '%s' but runc sent '%s' instead", expected, buf[:n])
}
}

func expectBarrier(t *testing.T, conn *net.UnixConn, notifyHostChan <-chan error) {
var msg, oob [1024]byte
n, oobn, _, _, err := conn.ReadMsgUnix(msg[:], oob[:])
if err != nil {
t.Fatal("Failed to receive BARRIER message", err)
}
if !bytes.Equal(msg[:n], []byte("BARRIER=1")) {
t.Fatalf("Expected to receive 'BARRIER=1' but got '%s' instead.", msg[:n])
}

fd := mustExtractFd(t, oob[:oobn])

// Test whether notifyHost actually honors the barrier
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-timer.C:
// this is the expected case
break
case <-notifyHostChan:
t.Fatal("runc has terminated before barrier was lifted")
}

// Lift the barrier
err = unix.Close(fd)
if err != nil {
t.Fatal(err)
}

// Expect notifyHost to terminate now
err = <-notifyHostChan
if err != nil {
t.Fatal("notifyHost function returned with error", err)
}
}

func mustExtractFd(t *testing.T, buf []byte) int {
cmsgs, err := unix.ParseSocketControlMessage(buf)
if err != nil {
t.Fatal("Failed to parse control message", err)
}

fd := 0
seenScmRights := false
for _, cmsg := range cmsgs {
if cmsg.Header.Type != unix.SCM_RIGHTS {
continue
}
if seenScmRights {
t.Fatal("Expected to see exactly one SCM_RIGHTS message, but got a second one")
}
seenScmRights = true
fds, err := unix.ParseUnixRights(&cmsg)
if err != nil {
t.Fatal("Failed to parse SCM_RIGHTS message", err)
}
if len(fds) != 1 {
t.Fatal("Expected to read exactly one file descriptor, but got", len(fds))
}
fd = fds[0]
}
if !seenScmRights {
t.Fatal("Control messages didn't contain an SCM_RIGHTS message")
}

return fd
}