diff options
Diffstat (limited to 'tests/root_test/root_test.go')
-rw-r--r-- | tests/root_test/root_test.go | 178 |
1 files changed, 169 insertions, 9 deletions
diff --git a/tests/root_test/root_test.go b/tests/root_test/root_test.go index 26caafc..079c03b 100644 --- a/tests/root_test/root_test.go +++ b/tests/root_test/root_test.go @@ -1,11 +1,19 @@ +// Package root_test contains tests that need root +// permissions to run package root_test import ( + "io/ioutil" "os" + "os/exec" + "path/filepath" "runtime" + "sync" "syscall" "testing" + "golang.org/x/sys/unix" + "github.com/rfjakob/gocryptfs/tests/test_helpers" ) @@ -13,25 +21,57 @@ func asUser(uid int, gid int, supplementaryGroups []int, f func() error) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - err := syscall.Setgroups(supplementaryGroups) + err := unix.Setgroups(supplementaryGroups) if err != nil { return err } - defer syscall.Setgroups(nil) - - err = syscall.Setregid(-1, gid) + defer func() { + err = unix.Setgroups(nil) + if err != nil { + panic(err) + } + }() + err = unix.Setregid(-1, gid) if err != nil { return err } - defer syscall.Setregid(-1, 0) - - err = syscall.Setreuid(-1, uid) + defer func() { + err = unix.Setregid(-1, 0) + if err != nil { + panic(err) + } + }() + err = unix.Setreuid(-1, uid) if err != nil { return err } - defer syscall.Setreuid(-1, 0) + defer func() { + err = unix.Setreuid(-1, 0) + if err != nil { + panic(err) + } + }() + + ret := f() - return f() + // Also reset the saved user id (suid) and saved group id (sgid) to prevent + // bizarre failures in later tests. + // + // Yes, the kernel checks that *all of them* match: + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/fs/fuse/dir.c?h=v5.12-rc2#n1193 + // + // How to check: + // ps -o tid,pid,euid,ruid,suid,egid,rgid,sgid,cmd -eL + err = unix.Setresuid(0, 0, 0) + if err != nil { + panic(err) + } + err = unix.Setresgid(0, 0, 0) + if err != nil { + panic(err) + } + + return ret } func TestSupplementaryGroups(t *testing.T) { @@ -73,3 +113,123 @@ func TestSupplementaryGroups(t *testing.T) { t.Error(err) } } + +func writeTillFull(t *testing.T, path string) (int, syscall.Errno) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + fd, err := syscall.Creat(path, 0600) + if err != nil { + return 0, err.(syscall.Errno) + } + defer syscall.Close(fd) + // Write in 100.000 byte-blocks, which is not aligend to the + // underlying block size + buf := make([]byte, 100000) + var sz int + for { + n, err := syscall.Write(fd, buf) + if err != nil { + return sz, err.(syscall.Errno) + } + sz += n + } + return sz, 0 +} + +func TestDiskFull(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("must run as root") + } + + // Create 10 MB file full of zeros + ext4img := filepath.Join(test_helpers.TmpDir, t.Name()+".ext4") + f, err := os.Create(ext4img) + if err != nil { + t.Fatal(err) + } + defer f.Close() + err = f.Truncate(10 * 1024 * 1024) + if err != nil { + t.Fatal(err) + } + + // Format as ext4 + cmd := exec.Command("mkfs.ext4", ext4img) + out, err := cmd.CombinedOutput() + if err != nil { + t.Log(string(out)) + t.Fatal(err) + } + + // Mount ext4 + ext4mnt := ext4img + ".mnt" + err = os.Mkdir(ext4mnt, 0600) + if err != nil { + t.Fatal(err) + } + cmd = exec.Command("mount", ext4img, ext4mnt) + out, err = cmd.CombinedOutput() + if err != nil { + t.Log(string(out)) + t.Fatal(err) + } + defer syscall.Unlink(ext4img) + defer syscall.Unmount(ext4mnt, 0) + + // gocryptfs -init + cipherdir := ext4mnt + "/a" + if err = os.Mkdir(cipherdir, 0600); err != nil { + t.Fatal(err) + } + cmd = exec.Command(test_helpers.GocryptfsBinary, "-q", "-init", "-extpass", "echo test", "-scryptn=10", cipherdir) + out, err = cmd.CombinedOutput() + if err != nil { + t.Log(string(out)) + t.Fatal(err) + } + + // Mount gocryptfs + mnt := ext4mnt + "/b" + err = os.Mkdir(mnt, 0600) + if err != nil { + t.Fatal(err) + } + test_helpers.MountOrFatal(t, cipherdir, mnt, "-extpass", "echo test") + defer test_helpers.UnmountPanic(mnt) + + // Write till we get ENOSPC + var err1, err2 error + var sz1, sz2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { + sz1, err1 = writeTillFull(t, mnt+"/foo1") + wg.Done() + }() + go func() { + sz2, err2 = writeTillFull(t, mnt+"/foo2") + wg.Done() + }() + wg.Wait() + if err1 != syscall.ENOSPC || err2 != syscall.ENOSPC { + t.Fatalf("err1=%v, err2=%v", err1, err2) + } + t.Logf("sz1=%d, sz2=%d", sz1, sz2) + + foo1, err := ioutil.ReadFile(mnt + "/foo1") + if err != nil { + t.Fatal(err) + } + if len(foo1) != sz1 { + t.Fail() + } + + foo2, err := ioutil.ReadFile(mnt + "/foo2") + if err != nil { + t.Fatal(err) + } + if len(foo2) != sz2 { + t.Fail() + } +} |