From 05c8d4a1c459dc4c73918af898119b1e17b1c0ce Mon Sep 17 00:00:00 2001
From: Jakob Unterwurzacher
Date: Sat, 22 Sep 2018 19:41:58 +0200
Subject: tests: add symlink_race tool

Help uncover symlink races.
---
 tests/symlink_race/.gitignore |  3 ++
 tests/symlink_race/main.go    | 91 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+)
 create mode 100644 tests/symlink_race/.gitignore
 create mode 100644 tests/symlink_race/main.go

diff --git a/tests/symlink_race/.gitignore b/tests/symlink_race/.gitignore
new file mode 100644
index 0000000..45df5f8
--- /dev/null
+++ b/tests/symlink_race/.gitignore
@@ -0,0 +1,3 @@
+symlink_race.test_file.tmp
+symlink_race.test_file
+symlink_race
diff --git a/tests/symlink_race/main.go b/tests/symlink_race/main.go
new file mode 100644
index 0000000..6fb794f
--- /dev/null
+++ b/tests/symlink_race/main.go
@@ -0,0 +1,91 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"syscall"
+	"time"
+)
+
+const (
+	testFile    = "symlink_race.test_file"
+	testFileTmp = testFile + ".tmp"
+)
+
+func renameLoop() {
+	// May be left behind from an earlier run
+	syscall.Unlink(testFileTmp)
+
+	var err error
+	var fd *os.File
+	for {
+		err = syscall.Symlink("/root/chmod_me", testFileTmp)
+		if err != nil {
+			fmt.Printf("Symlink() failed: %v\n", err)
+			continue
+		}
+		err = syscall.Rename(testFileTmp, testFile)
+		if err != nil {
+			fmt.Printf("Rename() 1 failed: %v\n", err)
+			continue
+		}
+		fd, err = os.Create(testFileTmp)
+		if err != nil {
+			fmt.Printf("Create() failed: %v\n", err)
+			continue
+		}
+		fd.Close()
+		err = syscall.Rename(testFileTmp, testFile)
+		if err != nil {
+			fmt.Printf("Rename() 2 failed: %v\n", err)
+			continue
+		}
+		fmt.Printf(".")
+	}
+}
+
+func chmodLoop() {
+	var err error
+	for {
+		err = syscall.Chmod(testFile, 0777)
+		if err != nil {
+			fmt.Printf("Chmod() failed: %v\n", err)
+		} else {
+			fmt.Printf("Chmod() ok\n")
+		}
+		time.Sleep(100 * time.Microsecond)
+	}
+}
+
+func openLoop() {
+	var err error
+	var f *os.File
+	buf := make([]byte, 100)
+	owned := []byte("owned")
+	var n int
+	for {
+		f, err = os.OpenFile(testFile, os.O_RDWR, 0777)
+		if err != nil {
+			fmt.Printf("Open() failed: %v\n", err)
+			continue
+		}
+		_, err = f.Write(owned)
+		if err != nil {
+			fmt.Printf("Write() failed: %v\n", err)
+		}
+		n, err = f.Read(buf)
+		if err != nil {
+			fmt.Printf("Read() failed: %v\n", err)
+			continue
+		}
+		if n > 0 {
+			fmt.Printf("Content: %q\n", string(buf[:n]))
+			os.Exit(1)
+		}
+	}
+}
+
+func main() {
+	go openLoop()
+	renameLoop()
+}
-- 
cgit v1.2.3