Skip to content

Commit

Permalink
notify_socket.go: Use sd_notify_barrier mechanism
Browse files Browse the repository at this point in the history
Signed-off-by: Jonas Eschenburg <jonas.eschenburg@kuka.com>
  • Loading branch information
Jonas Eschenburg committed Dec 2, 2021
1 parent d1f316f commit 007af91
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 21 deletions.
102 changes: 81 additions & 21 deletions notify_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"bytes"
"errors"
"io"
"net"
"os"
"path"
Expand All @@ -12,6 +14,7 @@ import (
"github.com/opencontainers/runc/libcontainer"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/urfave/cli"
"golang.org/x/sys/unix"
)

type notifySocket struct {
Expand Down Expand Up @@ -142,29 +145,86 @@ func (n *notifySocket) run(pid1 int) error {
return nil
}
case b := <-fileChan:
var out bytes.Buffer
_, err = out.Write(b)
if err != nil {
return err
}
return notifyHost(client, b, pid1)
}
}
}

_, err = out.Write([]byte{'\n'})
if err != nil {
return err
}
// notifyHost tells the host (usually systemd) that the container reported READY.
// Also sends MAINPID and BARRIER.
func notifyHost(client *net.UnixConn, ready []byte, pid1 int) error {
var out bytes.Buffer
_, err := out.Write(ready)
if err != nil {
return err
}

_, err = client.Write(out.Bytes())
if err != nil {
return err
}
_, err = out.Write([]byte{'\n'})
if err != nil {
return err
}

// now we can inform systemd to use pid1 as the pid to monitor
newPid := "MAINPID=" + strconv.Itoa(pid1)
_, err := client.Write([]byte(newPid + "\n"))
if err != nil {
return err
}
return nil
}
_, err = client.Write(out.Bytes())
if err != nil {
return err
}

// now we can inform systemd to use pid1 as the pid to monitor
newPid := "MAINPID=" + strconv.Itoa(pid1)
_, err = client.Write([]byte(newPid + "\n"))
if err != nil {
return err
}

// wait for systemd to acknowledge the communication
return sdNotifyBarrier(client)
}

// errUnexpectedRead is reported when actual data was read from the pipe used
// to synchronize with systemd. Usually, that pipe is only closed.
var errUnexpectedRead = errors.New("unexpected read from synchronization pipe")

// sdNotifyBarrier performs synchronization with systemd by means of the
// sd_notify_barrier protocol.
//
// It creates a pipe, passes the write end to the peer of client together with
// a "BARRIER=1" message (as SCM_RIGHTS ancillary data), closes the local copy
// of the write end, and then waits up to 5 seconds for the read end to report
// EOF, which happens once the peer has closed its copy of the write end.
//
// It returns nil on EOF, errUnexpectedRead if data arrives on the pipe (or the
// read succeeds without error), and the underlying error otherwise, including
// a timeout error when the deadline expires. Every descriptor opened here is
// closed before returning.
func sdNotifyBarrier(client *net.UnixConn) error {
	// Create a pipe for communicating with systemd daemon.
	pipeR, pipeW, err := os.Pipe()
	if err != nil {
		return err
	}
	defer pipeR.Close()

	// Get a file for the unix socket to be able to perform unix.Sendmsg.
	// client.File returns a duplicate that the caller must close; closing it
	// does not affect client itself.
	clientFd, err := client.File()
	if err != nil {
		_ = pipeW.Close() // best effort; the File error is the one to report
		return err
	}
	defer clientFd.Close()

	// Send the write end of the pipe along with a BARRIER=1 message.
	fdRights := unix.UnixRights(int(pipeW.Fd()))
	err = unix.Sendmsg(int(clientFd.Fd()), []byte("BARRIER=1"), fdRights, nil, 0)
	if err != nil {
		_ = pipeW.Close() // best effort; the sendmsg error is the one to report
		return &os.SyscallError{Syscall: "sendmsg", Err: err}
	}

	// Close our copy of pipeW. This must happen before reading: as long as we
	// hold a write end ourselves, the read below can never observe EOF.
	if err := pipeW.Close(); err != nil {
		return err
	}

	// Expect the read end of the pipe to be closed within 5 seconds.
	// Without a deadline the read below could block forever, and silently
	// skipping the wait would report a barrier that was never honored, so a
	// failure to arm the deadline is returned to the caller.
	if err := pipeR.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return err
	}

	// Read a single byte expecting EOF.
	var buf [1]byte
	n, err := pipeR.Read(buf[:])
	switch {
	case n != 0 || err == nil:
		return errUnexpectedRead
	case errors.Is(err, io.EOF):
		return nil
	default:
		return err
	}
}
120 changes: 120 additions & 0 deletions notify_socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package main

import (
"bytes"
"io"
"net"
"testing"
"time"

"golang.org/x/sys/unix"
)

// TestNotifyHost tests how runc reports container readiness to the host (usually systemd).
//
// A unixgram socket pair stands in for the host's notify socket. notifyHost
// runs in its own goroutine while the test body plays the host: it expects the
// READY datagram, then the MAINPID datagram, then the BARRIER handshake.
func TestNotifyHost(t *testing.T) {
	addr := net.UnixAddr{
		Name: t.TempDir() + "/testsocket",
		Net:  "unixgram",
	}

	server, err := net.ListenUnixgram("unixgram", &addr)
	if err != nil {
		t.Fatal(err)
	}
	defer server.Close()

	// Bound every read of the mock host so that the test fails instead of
	// hanging if notifyHost bails out before sending all of its messages.
	if err := server.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatal(err)
	}

	client, err := net.DialUnix("unixgram", nil, &addr)
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	// Run notifyHost in a separate goroutine. The channel is buffered so the
	// goroutine can always deliver its result and exit, even when the test
	// has already failed and nobody receives anymore.
	notifyHostChan := make(chan error, 1)
	go func() {
		notifyHostChan <- notifyHost(client, []byte("READY=42"), 1337)
	}()

	// Mock a host process listening for runc's notifications.
	expectRead(t, server, "READY=42\n")
	expectRead(t, server, "MAINPID=1337\n")
	expectBarrier(t, server, notifyHostChan)
}

// expectRead performs a single Read on r (into a 1024-byte buffer) and fails
// the test immediately if the read returns an error or if the bytes read
// differ from expected.
func expectRead(t *testing.T, r io.Reader, expected string) {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()

	var buf [1024]byte
	n, err := r.Read(buf[:])
	if err != nil {
		t.Fatal(err)
	}
	if got := buf[:n]; !bytes.Equal(got, []byte(expected)) {
		// %q makes otherwise invisible differences (such as a missing
		// trailing newline) show up in the failure message.
		t.Fatalf("Expected to read %q but runc sent %q instead", expected, got)
	}
}

// expectBarrier plays the host side of the barrier handshake.
//
// It reads one message from conn, requires its payload to be "BARRIER=1" and
// its control data to carry exactly one file descriptor (see mustExtractFd).
// It then checks that nothing arrives on notifyHostChan for 500ms, i.e. that
// the sender is still waiting on the barrier, closes the received descriptor
// to lift the barrier, and finally requires a nil result on notifyHostChan.
func expectBarrier(t *testing.T, conn *net.UnixConn, notifyHostChan <-chan error) {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()

	var msg, oob [1024]byte
	n, oobn, _, _, err := conn.ReadMsgUnix(msg[:], oob[:])
	if err != nil {
		t.Fatal("Failed to receive BARRIER message", err)
	}
	if !bytes.Equal(msg[:n], []byte("BARRIER=1")) {
		t.Fatalf("Expected to receive 'BARRIER=1' but got '%s' instead.", msg[:n])
	}

	fd := mustExtractFd(t, oob[:oobn])

	// Test whether notifyHost actually honors the barrier.
	timer := time.NewTimer(500 * time.Millisecond)
	defer timer.Stop()
	select {
	case <-timer.C:
		// This is the expected case: the sender is still blocked.
	case err := <-notifyHostChan:
		// Do not leak the received descriptor on the failure path.
		_ = unix.Close(fd)
		t.Fatalf("runc has terminated before barrier was lifted (notifyHost returned: %v)", err)
	}

	// Lift the barrier.
	err = unix.Close(fd)
	if err != nil {
		t.Fatal(err)
	}

	// Expect notifyHost to terminate now.
	err = <-notifyHostChan
	if err != nil {
		t.Fatal("notifyHost function returned with error", err)
	}
}

// mustExtractFd parses buf as socket control messages and returns the single
// file descriptor carried by the single SCM_RIGHTS message in it.
//
// The test fails immediately if buf cannot be parsed, if it contains no
// SCM_RIGHTS message or more than one, or if the SCM_RIGHTS message does not
// carry exactly one descriptor. Control messages of other types are ignored.
func mustExtractFd(t *testing.T, buf []byte) int {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()

	cmsgs, err := unix.ParseSocketControlMessage(buf)
	if err != nil {
		t.Fatal("Failed to parse control message", err)
	}

	// -1 is never a valid descriptor, unlike 0 (stdin); it is only a
	// placeholder until the SCM_RIGHTS message has been seen.
	fd := -1
	seenScmRights := false
	for _, cmsg := range cmsgs {
		if cmsg.Header.Type != unix.SCM_RIGHTS {
			continue
		}
		if seenScmRights {
			t.Fatal("Expected to see exactly one SCM_RIGHTS message, but got a second one")
		}
		seenScmRights = true

		fds, err := unix.ParseUnixRights(&cmsg)
		if err != nil {
			t.Fatal("Failed to parse SCM_RIGHTS message", err)
		}
		if len(fds) != 1 {
			t.Fatal("Expected to read exactly one file descriptor, but got", len(fds))
		}
		fd = fds[0]
	}
	if !seenScmRights {
		t.Fatal("Control messages didn't contain an SCM_RIGHTS message")
	}

	return fd
}

0 comments on commit 007af91

Please sign in to comment.