aboutsummaryrefslogtreecommitdiff
path: root/tests/root_test
diff options
context:
space:
mode:
Diffstat (limited to 'tests/root_test')
-rw-r--r--tests/root_test/root_test.go178
1 files changed, 169 insertions, 9 deletions
diff --git a/tests/root_test/root_test.go b/tests/root_test/root_test.go
index 26caafc..079c03b 100644
--- a/tests/root_test/root_test.go
+++ b/tests/root_test/root_test.go
@@ -1,11 +1,19 @@
+// Package root_test contains tests that need root
+// permissions to run
package root_test
import (
+ "io/ioutil"
"os"
+ "os/exec"
+ "path/filepath"
"runtime"
+ "sync"
"syscall"
"testing"
+ "golang.org/x/sys/unix"
+
"github.com/rfjakob/gocryptfs/tests/test_helpers"
)
@@ -13,25 +21,57 @@ func asUser(uid int, gid int, supplementaryGroups []int, f func() error) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
- err := syscall.Setgroups(supplementaryGroups)
+ err := unix.Setgroups(supplementaryGroups)
if err != nil {
return err
}
- defer syscall.Setgroups(nil)
-
- err = syscall.Setregid(-1, gid)
+ defer func() {
+ err = unix.Setgroups(nil)
+ if err != nil {
+ panic(err)
+ }
+ }()
+ err = unix.Setregid(-1, gid)
if err != nil {
return err
}
- defer syscall.Setregid(-1, 0)
-
- err = syscall.Setreuid(-1, uid)
+ defer func() {
+ err = unix.Setregid(-1, 0)
+ if err != nil {
+ panic(err)
+ }
+ }()
+ err = unix.Setreuid(-1, uid)
if err != nil {
return err
}
- defer syscall.Setreuid(-1, 0)
+ defer func() {
+ err = unix.Setreuid(-1, 0)
+ if err != nil {
+ panic(err)
+ }
+ }()
+
+ ret := f()
- return f()
+ // Also reset the saved user id (suid) and saved group id (sgid) to prevent
+ // bizarre failures in later tests.
+ //
+ // Yes, the kernel checks that *all of them* match:
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/fs/fuse/dir.c?h=v5.12-rc2#n1193
+ //
+ // How to check:
+ // ps -o tid,pid,euid,ruid,suid,egid,rgid,sgid,cmd -eL
+ err = unix.Setresuid(0, 0, 0)
+ if err != nil {
+ panic(err)
+ }
+ err = unix.Setresgid(0, 0, 0)
+ if err != nil {
+ panic(err)
+ }
+
+ return ret
}
func TestSupplementaryGroups(t *testing.T) {
@@ -73,3 +113,123 @@ func TestSupplementaryGroups(t *testing.T) {
t.Error(err)
}
}
+
+func writeTillFull(t *testing.T, path string) (int, syscall.Errno) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ fd, err := syscall.Creat(path, 0600)
+ if err != nil {
+ return 0, err.(syscall.Errno)
+ }
+ defer syscall.Close(fd)
+ // Write in 100.000 byte-blocks, which is not aligend to the
+ // underlying block size
+ buf := make([]byte, 100000)
+ var sz int
+ for {
+ n, err := syscall.Write(fd, buf)
+ if err != nil {
+ return sz, err.(syscall.Errno)
+ }
+ sz += n
+ }
+ return sz, 0
+}
+
+func TestDiskFull(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skip("must run as root")
+ }
+
+ // Create 10 MB file full of zeros
+ ext4img := filepath.Join(test_helpers.TmpDir, t.Name()+".ext4")
+ f, err := os.Create(ext4img)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer f.Close()
+ err = f.Truncate(10 * 1024 * 1024)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Format as ext4
+ cmd := exec.Command("mkfs.ext4", ext4img)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Log(string(out))
+ t.Fatal(err)
+ }
+
+ // Mount ext4
+ ext4mnt := ext4img + ".mnt"
+ err = os.Mkdir(ext4mnt, 0600)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cmd = exec.Command("mount", ext4img, ext4mnt)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ t.Log(string(out))
+ t.Fatal(err)
+ }
+ defer syscall.Unlink(ext4img)
+ defer syscall.Unmount(ext4mnt, 0)
+
+ // gocryptfs -init
+ cipherdir := ext4mnt + "/a"
+ if err = os.Mkdir(cipherdir, 0600); err != nil {
+ t.Fatal(err)
+ }
+ cmd = exec.Command(test_helpers.GocryptfsBinary, "-q", "-init", "-extpass", "echo test", "-scryptn=10", cipherdir)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ t.Log(string(out))
+ t.Fatal(err)
+ }
+
+ // Mount gocryptfs
+ mnt := ext4mnt + "/b"
+ err = os.Mkdir(mnt, 0600)
+ if err != nil {
+ t.Fatal(err)
+ }
+ test_helpers.MountOrFatal(t, cipherdir, mnt, "-extpass", "echo test")
+ defer test_helpers.UnmountPanic(mnt)
+
+ // Write till we get ENOSPC
+ var err1, err2 error
+ var sz1, sz2 int
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ sz1, err1 = writeTillFull(t, mnt+"/foo1")
+ wg.Done()
+ }()
+ go func() {
+ sz2, err2 = writeTillFull(t, mnt+"/foo2")
+ wg.Done()
+ }()
+ wg.Wait()
+ if err1 != syscall.ENOSPC || err2 != syscall.ENOSPC {
+ t.Fatalf("err1=%v, err2=%v", err1, err2)
+ }
+ t.Logf("sz1=%d, sz2=%d", sz1, sz2)
+
+ foo1, err := ioutil.ReadFile(mnt + "/foo1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(foo1) != sz1 {
+ t.Fail()
+ }
+
+ foo2, err := ioutil.ReadFile(mnt + "/foo2")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(foo2) != sz2 {
+ t.Fail()
+ }
+}