Skip to content

Commit

Permalink
Merge pull request #105 from potoo0/main
Browse files Browse the repository at this point in the history
feat: writefile magic func
  • Loading branch information
janpfeifer authored Apr 5, 2024
2 parents 5912614 + 49a8b4a commit 4027fab
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 5 deletions.
33 changes: 33 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,36 @@ func (f *ArrayFlag) Set(value string) error {
*f = append(*f, value)
return nil
}

// FlagsParse parse args to map
func FlagsParse(args []string, noValArg Set[string], schema map[string]string) map[string]string {
keyPos := 0 // position arg
keyGen := func() string {
keyPos++
return fmt.Sprintf("-pos%d", keyPos)
}
resultMap := make(map[string]string)
var key string
for _, arg := range args {
switch {
case len(arg) > 2 && arg[:2] == "--":
key = arg[2:]
resultMap[key] = ""
case len(arg) > 1 && arg[0] == '-':
d, ok := schema[arg[1:]]
if ok && len(d) > 0 {
key = d
} else {
key = arg[1:]
}
resultMap[key] = ""
case len(arg) > 0 && arg[0] != '-':
if noValArg.Has(key) || key == "" {
key = keyGen()
}
resultMap[key] = arg
key = ""
}
}
return resultMap
}
31 changes: 31 additions & 0 deletions common/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package common

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestFlagsParse(t *testing.T) {
args := []string{
"-a",
"pos-arg1",
"-b",
"vb",
"pos-arg2",
}
noValArg := MakeSet[string](2)
noValArg.Insert("append")
schema := map[string]string{
"a": "append",
"b": "block",
}
actual := FlagsParse(args, noValArg, schema)

expected := map[string]string{
"-pos1": "pos-arg1",
"-pos2": "pos-arg2",
"append": "",
"block": "vb",
}
assert.Equal(t, expected, actual)
}
70 changes: 65 additions & 5 deletions internal/specialcmd/specialcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type cellStatus struct {
func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []string, usedLines Set[int]) (err error) {
status := &cellStatus{}
for lineNum := 0; lineNum < len(codeLines); lineNum++ {
if _, found := usedLines[lineNum]; found {
if usedLines.Has(lineNum) {
continue
}
line := codeLines[lineNum]
Expand All @@ -65,9 +65,19 @@ func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []s
if execute {
switch cmdType {
case '%':
err = execInternal(msg, goExec, cmdStr, status)
if err != nil {
return
parts := splitCmd(cmdStr)
// optimize...
if len(parts) > 0 && parts[0] == "writefile" {
cmdBody := parseCmdBody(codeLines, lineNum, usedLines)
err = execWriteFile(msg, goExec, parts[1:], cmdBody)
if err != nil {
return
}
} else {
err = execInternal(msg, goExec, cmdStr, status)
if err != nil {
return
}
}
case '!':
err = execShell(msg, goExec, cmdStr, status)
Expand Down Expand Up @@ -95,7 +105,7 @@ func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []s
func joinLine(lines []string, fromLine int, usedLines Set[int]) (cmdStr string) {
for ; fromLine < len(lines); fromLine++ {
cmdStr += lines[fromLine]
usedLines[fromLine] = struct{}{}
usedLines.Insert(fromLine)
if cmdStr[len(cmdStr)-1] != '\\' {
return
}
Expand All @@ -104,6 +114,23 @@ func joinLine(lines []string, fromLine int, usedLines Set[int]) (cmdStr string)
return
}

// parseCmdBody starts from fromLine and joins consecutive lines until the line start with magic symbol( % ! )
//
// It returns the joined lines with the '\n', and appends the used lines (including fromLine) to usedLines.
func parseCmdBody(lines []string, fromLine int, usedLines Set[int]) (cmdBody string) {
usedLines.Insert(fromLine)
fromLine++
for ; fromLine < len(lines); fromLine++ {
if len(lines[fromLine]) > 0 && (lines[fromLine][0] == '%' || lines[fromLine][0] == '!') {
return
}
cmdBody += lines[fromLine]
cmdBody += "\n"
usedLines.Insert(fromLine)
}
return
}

// execInternal executes internal configuration commands, see HelpMessage for details.
//
// It only returns errors for system errors that will lead to the kernel restart. Syntax errors
Expand Down Expand Up @@ -275,6 +302,39 @@ func execInternal(msg kernel.Message, goExec *goexec.State, cmdStr string, statu
return nil
}

// execWriteFile write cell body to file
func execWriteFile(msg kernel.Message, goExec *goexec.State, args []string, cmdBody string) error {
// parse arg
noValArg := MakeSet[string](2)
noValArg.Insert("append")
schema := map[string]string{"a": "append"}
parse := FlagsParse(args, noValArg, schema)
_, appendMode := parse["append"]
filename, hasFileName := parse["-pos1"]
if !hasFileName {
filename = goExec.UniqueID + ".out"
}

// do write
fileFlag := os.O_RDWR | os.O_CREATE
if appendMode {
fileFlag |= os.O_APPEND
} else {
fileFlag |= os.O_TRUNC
}
file, err := os.OpenFile(filename, fileFlag, 0666)
if err != nil {
return err
}
defer file.Close()

_, err = file.WriteString(cmdBody)
if err != nil {
return err
}
return kernel.PublishWriteStream(msg, kernel.StreamStdout, "write to "+filename+" success\n")
}

// execInternal executes internal configuration commands, see HelpMessage for details.
//
// It only returns errors for system errors that will lead to the kernel restart. Syntax errors
Expand Down
58 changes: 58 additions & 0 deletions internal/specialcmd/specialcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,61 @@ func TestDirEnv(t *testing.T) {
assert.Equal(t, "/tmp", os.Getenv(protocol.GONB_DIR_ENV))
require.NoError(t, s.Stop())
}

func TestMagicWrite(t *testing.T) {
s := newEmptyState(t)

expected := `fmt.Println("1")
fmt.Println("2")
// !*cat main.go
`

type TestCase struct {
appendMode bool
filename, src, fileContent string
}
srcGen := func(testCase *TestCase) {
var appendArg string
if testCase.appendMode {
appendArg = " -a "
}
testCase.src = `%writefile ` + appendArg + testCase.filename + "\n" + expected + "%%\nfmt.Println(1)"
}

// build test cases
testCases := []*TestCase{
{false, "", "", expected},
{true, "", "", strings.Repeat(expected, 2)},
{false, "/tmp/TestMagicWrite.log", "", expected},
}
for _, testCase := range testCases {
srcGen(testCase)
}

// run test cases
fileClean := MakeSet[string](4)
defer func() {
for filename := range fileClean {
defer os.Remove(filename)
}
}()
for idx, testCase := range testCases {
t.Run(fmt.Sprintf("test-case-%d", idx), func(t *testing.T) {
filename := testCase.filename
if filename == "" {
filename = s.UniqueID + ".out"
}
fileClean.Insert(filename)

var msg kernel.Message
usedLines := MakeSet[int]()
lines := strings.Split(testCase.src, "\n")
err := Parse(msg, s, true, lines, usedLines)
require.NoError(t, err)

fileBytes, err := os.ReadFile(filename)
require.NoError(t, err)
assert.Equal(t, testCase.fileContent, string(fileBytes))
})
}
}

0 comments on commit 4027fab

Please sign in to comment.