diff --git a/appWindow.go b/appWindow.go new file mode 100644 index 0000000..63ff6dd --- /dev/null +++ b/appWindow.go @@ -0,0 +1,120 @@ +package main + +import ( + "fmt" + "image/color" + "okemu/okean240" + "okemu/okean240/fdc" + + "fyne.io/fyne/v2" + "fyne.io/fyne/v2/app" + "fyne.io/fyne/v2/canvas" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/driver/desktop" + "fyne.io/fyne/v2/layout" + "fyne.io/fyne/v2/theme" + "fyne.io/fyne/v2/widget" +) + +func mainWindow(computer *okean240.ComputerType) (*fyne.Window, *canvas.Raster, *widget.Label) { + emulatorApp := app.New() + w := emulatorApp.NewWindow("Океан 240.2") + w.Canvas().SetOnTypedKey( + func(key *fyne.KeyEvent) { + computer.PutKey(key) + }) + + w.Canvas().SetOnTypedRune( + func(key rune) { + computer.PutRune(key) + }) + + addShortcuts(w.Canvas(), computer) + + label := widget.NewLabel(fmt.Sprintf("Screen size: %dx%d", computer.ScreenWidth(), computer.ScreenHeight())) + + raster := canvas.NewRasterWithPixels( + func(x, y, w, h int) color.Color { + var xx uint16 + if computer.ScreenWidth() == 512 { + xx = uint16(x) + } else { + xx = uint16(x) / 2 + } + return computer.GetPixel(xx, uint16(y/2)) + }) + raster.Resize(fyne.NewSize(512, 512)) + raster.SetMinSize(fyne.NewSize(512, 512)) + + centerRaster := container.NewCenter(raster) + + w.Resize(fyne.NewSize(600, 600)) + + vBox := container.NewVBox( + newToolbar(computer, w), + centerRaster, + label, + ) + + w.SetContent(vBox) + + return &w, raster, label +} + +func newToolbar(c *okean240.ComputerType, w fyne.Window) *fyne.Container { + hBox := container.NewHBox() + for d := 0; d < fdc.TotalDrives; d++ { + hBox.Add(widget.NewLabel(string(rune(66+d)) + ":")) + hBox.Add(widget.NewToolbar( + widget.NewToolbarAction(theme.DocumentSaveIcon(), func() { + err := c.SaveFloppy(fdc.FloppyB) + if err != nil { + dialog.ShowError(err, w) + } + }), + //widget.NewToolbarSpacer(), + widget.NewToolbarAction(theme.FolderOpenIcon(), func() { + err := c.SaveFloppy(fdc.FloppyC) + if err != nil { + dialog.ShowError(err, w) + } + }), + )) + } + hBox.Add(widget.NewSeparator()) + hBox.Add(widget.NewButtonWithIcon("Ctrl+C", theme.LogoutIcon(), func() { + c.PutCtrlKey(0x03) + })) + hBox.Add(widget.NewSeparator()) + bNorm := widget.NewButtonWithIcon("", theme.MediaPlayIcon(), func() { + fullSpeed.Store(false) + c.SetCPUFrequency(2_500_000) + //bNorm.Disable() + + }) + bFast := widget.NewButtonWithIcon("", theme.MediaFastForwardIcon(), func() { + fullSpeed.Store(true) + c.SetCPUFrequency(50_000_000) + bNorm.Enable() + //bFast.Disable() + }) + hBox.Add(bNorm) + hBox.Add(bFast) + hBox.Add(layout.NewSpacer()) + hBox.Add(widget.NewButtonWithIcon("Reset", theme.MediaReplayIcon(), func() { + needReset = true + //computer.Reset(conf) + })) + return hBox +} + +// Add shortcuts for all Ctrl+ +func addShortcuts(c fyne.Canvas, computer *okean240.ComputerType) { + // Add shortcuts for Ctrl+A to Ctrl+Z + for kName := 'A'; kName <= 'Z'; kName++ { + kk := fyne.KeyName(kName) + sc := &desktop.CustomShortcut{KeyName: kk, Modifier: fyne.KeyModifierControl} + c.AddShortcut(sc, func(shortcut fyne.Shortcut) { computer.PutCtrlKey(byte(kName&0xff) - 0x40) }) + } +} diff --git a/config/config.go b/config/config.go index 4d97f56..ce2bf5a 100644 --- a/config/config.go +++ b/config/config.go @@ -12,14 +12,25 @@ const defaultCPMFile = "rom/CPM_v5.bin" const DefaultDebufPort = 10000 type OkEmuConfig struct { - LogFile string `yaml:"logFile"` - LogLevel string `yaml:"logLevel"` - MonitorFile string `yaml:"monitorFile"` - CPMFile string `yaml:"cpmFile"` - FloppyB string `yaml:"floppyB"` - FloppyC string `yaml:"floppyC"` - Host string `yaml:"host"` - Port int `yaml:"port"` + LogFile string `yaml:"logFile"` + LogLevel string `yaml:"logLevel"` + MonitorFile string `yaml:"monitorFile"` + CPMFile string `yaml:"cpmFile"` + FDC FDCConfig `yaml:"fdc"` + Debugger DebuggerConfig `yaml:"debugger"` +} + +type FDCConfig struct { + AutoLoadB bool `yaml:"autoLoadB"` + AutoLoadC bool `yaml:"autoLoadC"` + FloppyB string `yaml:"floppyB"` + FloppyC string `yaml:"floppyC"` +} + +type DebuggerConfig struct { + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` } var config *OkEmuConfig @@ -62,8 +73,8 @@ func LoadConfig() { } func checkConfig(conf *OkEmuConfig) { - if conf.Host == "" { - conf.Host = "localhost" + if conf.Debugger.Host == "" { + conf.Debugger.Host = "localhost" } } @@ -80,8 +91,8 @@ func setDefaultConf(conf *OkEmuConfig) { if conf.CPMFile == "" { conf.CPMFile = defaultCPMFile } - if conf.Port < 80 || conf.Port > 65535 { - log.Infof("Port %d incorrect, using default: %d", conf.Port, DefaultDebufPort) - conf.Port = DefaultDebufPort + if conf.Debugger.Port < 80 || conf.Debugger.Port > 65535 { + log.Infof("Port %d incorrect, using default: %d", conf.Debugger.Port, DefaultDebufPort) + conf.Debugger.Port = DefaultDebufPort } } diff --git a/debug/breakpoint/breakpoint.go b/debug/breakpoint/breakpoint.go new file mode 100644 index 0000000..e9d92ea --- /dev/null +++ b/debug/breakpoint/breakpoint.go @@ -0,0 +1,246 @@ +package breakpoint + +import ( + "context" + "fmt" + "okemu/gval" + "regexp" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +const MaxBreakpoints = 256 + +const ( + BPTypeSimplePC = iota // Simple PC=nn breakpoint + BPTypeSimpleSP // Simple SP>=nn breakpoint + BPTypeExpression // Complex expression breakpoint +) + +type Breakpoint struct { + addr uint16 + cond string + eval gval.Evaluable + bpType int + pass uint16 + passCount uint16 + enabled bool +} + +var andMatch = regexp.MustCompile(`\s+AND\s+`) +var orMatch = regexp.MustCompile(`\s+OR\s+`) + +// var xorMatch = regexp.MustCompile(`\s+XOR\s+`) +var hexHMatch = regexp.MustCompile(`[[:xdigit:]]+H`) +var eqMatch = regexp.MustCompile(`[^=><]=[^=]`) + +func patchExpression(expr string) string { + ex := strings.ToUpper(expr) + ex = andMatch.ReplaceAllString(ex, " && ") + ex = orMatch.ReplaceAllString(ex, " || ") + // ex = xorMatch.ReplaceAllString(ex, " ^ ") + for { + pos := hexHMatch.FindStringIndex(ex) + if pos != nil && len(pos) == 2 { + hex := "0x" + ex[pos[0]:pos[1]-1] + ex = ex[:pos[0]] + hex + ex[pos[1]:] + } else { + break + } + } + for { + pos := eqMatch.FindStringIndex(ex) + if pos != nil && len(pos) == 2 { + ex = ex[:pos[0]+1] + "==" + ex[pos[1]-1:] + } else { + break + } + } + return ex +} + +var pcMatch = regexp.MustCompile(`^PC=[[:xdigit:]]+h$`) +var spMatch = regexp.MustCompile(`^SP>=[[:xdigit:]]+$`) + +func getSecondUint16(param string, sep string) (uint16, error) { + p := strings.Split(param, sep) + v := p[1] + base := 0 + if strings.HasSuffix(v, "h") || strings.HasSuffix(v, "H") { + v = strings.TrimSuffix(v, "H") + v = strings.TrimSuffix(v, "h") + base = 16 + } + a, e := strconv.ParseUint(v, base, 16) + if e != nil { + return 0, e + } + return uint16(a), nil +} + +func NewBreakpoint(expr string) (*Breakpoint, error) { + bp := Breakpoint{ + addr: 0, + enabled: false, + passCount: 0, + pass: 0, + bpType: BPTypeSimplePC, + } + + // Check if BP is simple PC=addr + expr = strings.TrimSpace(expr) + bp.cond = expr + pcMatched := pcMatch.MatchString(expr) + spMatched := spMatch.MatchString(expr) + + if pcMatched { + // PC=xxxxh + bp.bpType = BPTypeSimplePC + v, e := getSecondUint16(expr, "=") + if e != nil { + return nil, e + } + bp.addr = v + } else if spMatched { + // SP>=xxxx + bp.bpType = BPTypeSimpleSP + v, e := getSecondUint16(expr, "=") + if e != nil { + return nil, e + } + bp.addr = v + } else { + // complex expression + bp.bpType = BPTypeExpression + ex := patchExpression(expr) + log.Debugf("Original Expression: '%s'", expr) + log.Debugf(" Patched Expression: '%s'", ex) + err := bp.SetExpression(ex) + if err != nil { + return nil, err + } + } + return &bp, nil +} + +func (b *Breakpoint) Enabled() bool { + return b.enabled +} +func (b *Breakpoint) SetEnabled(enabled bool) { + b.enabled = enabled +} +func (b *Breakpoint) PassCount() uint16 { + return b.passCount +} +func (b *Breakpoint) SetPassCount(passCount uint16) { + b.passCount = passCount +} +func (b *Breakpoint) Pass() uint16 { + return b.pass +} +func (b *Breakpoint) SetPass(pass uint16) { + b.pass = pass +} + +func (b *Breakpoint) IncPass() { + b.pass++ +} + +func (b *Breakpoint) Addr() uint16 { + return b.addr +} +func (b *Breakpoint) SetAddr(addr uint16) { + b.addr = addr +} +func (b *Breakpoint) Type() int { + return b.bpType +} +func (b *Breakpoint) SetType(bpType int) { + b.bpType = bpType +} + +func getUint16(name string, ctx map[string]interface{}) uint16 { + if v, ok := ctx[name]; ok { + if v == nil { + return 0 + } + // most frequent case + if v, ok := v.(uint16); ok { + return v + } + // for less frequent cases + switch value := v.(type) { + case int: + return uint16(value) + case int8: + return uint16(value) + case int16: + return uint16(value) + case int32: + return uint16(value) + case int64: + return uint16(value) + case uint: + return uint16(value) + case uint8: + return uint16(value) + case uint32: + return uint16(value) + case uint64: + return uint16(value) + default: + log.Errorf("Unknown type %v for variable %s", value, name) + return 0 + } + } else { + log.Errorf("Variable %s not found in context!", name) + } + return 0 +} + +func (b *Breakpoint) Hit(ctx map[string]interface{}) bool { + if !b.enabled { + return false + } + if b.bpType == BPTypeSimplePC { + pc := getUint16("PC", ctx) + if pc == b.addr { + log.Debugf("Breakpoint Hit PC=%04X", b.addr) + } + return pc == b.addr + } else if b.bpType == BPTypeSimpleSP { + sp := getUint16("SP", ctx) + if sp >= b.addr { + log.Debugf("Breakpoint Hit SP>=%04X", b.addr) + } + return sp >= b.addr + } + value, err := b.eval.EvalBool(context.Background(), ctx) + if err != nil { + fmt.Println(err) + } + return value +} + +var language gval.Language + +func init() { + language = gval.NewLanguage(gval.Base(), gval.Arithmetic(), gval.Bitmask(), gval.PropositionalLogic()) +} + +func (b *Breakpoint) SetExpression(expression string) error { + var err error + b.eval, err = language.NewEvaluable(expression) + if err != nil { + log.Error("Illegal expression", err) + return err + } + b.bpType = BPTypeExpression + return nil +} + +func (b *Breakpoint) Expression() string { + return b.cond +} diff --git a/debug/breakpoint/breakpoint_test.go b/debug/breakpoint/breakpoint_test.go new file mode 100644 index 0000000..19d1081 --- /dev/null +++ b/debug/breakpoint/breakpoint_test.go @@ -0,0 +1,76 @@ +package breakpoint + +import ( + "regexp" + "testing" +) + +const expr1 = "PC=00100h && SP>=256" +const expr2 = "PC=00115h" +const expr3 = "SP>=1332" + +var ctx = map[string]interface{}{ + "PC": 0x100, + "A": 0x55, + "SP": 0x200, +} + +const exprRep = "PC=00115h and (B=5 or BC = 5)" +const exprDst = "PC==0x00115 && (B==5 || BC == 5)" + +func Test_PatchExpression(t *testing.T) { + ex := patchExpression(exprRep) + if ex != exprDst { + t.Errorf("Patched expression does not match\n got: %s\nexpected: %s", ex, exprDst) + } + +} + +func Test_ComplexExpr(t *testing.T) { + + b, e := NewBreakpoint(expr1) + //e := b.SetExpression(exp1) + if e != nil { + t.Error(e) + } else if b != nil { + b.enabled = true + if !b.Hit(ctx) { + t.Errorf("Breakpoint not hit") + } + } + +} + +const expSimplePC = "PC=00119h" + +func Test_BPSetPC(t *testing.T) { + b, e := NewBreakpoint(expSimplePC) + if e != nil { + t.Error(e) + } else if b != nil { + if b.bpType != BPTypeSimplePC { + t.Errorf("Breakpoint type does not match BPTypeSimplePC") + } + b.enabled = true + if b.Hit(ctx) { + t.Errorf("Breakpoint hit but will not!") + } + } +} + +func Test_MatchSP(t *testing.T) { + + pcMatch := regexp.MustCompile(`SP>=[[:xdigit:]]+$`) + matched := pcMatch.MatchString(expr3) + if !matched { + t.Errorf("SP>=XXXXh not matched") + } + +} + +func Test_GetCtx(t *testing.T) { + pc := getUint16("PC", ctx) + if pc != 0x100 { + t.Errorf("PC value not found in context") + } +} diff --git a/debug/debuger.go b/debug/debuger.go new file mode 100644 index 0000000..db1dd44 --- /dev/null +++ b/debug/debuger.go @@ -0,0 +1,225 @@ +package debug + +import ( + "okemu/debug/breakpoint" + "okemu/z80" + "okemu/z80/dis" + + log "github.com/sirupsen/logrus" +) + +const BPMemAccess = 65535 + +type Debugger struct { + stepMode bool + doStep bool + runMode bool + runInst uint64 + breakpointsEnabled bool + breakpoints map[uint16]*breakpoint.Breakpoint + cpuFrequency uint32 + disassembler *dis.Disassembler + cpuHistoryEnabled bool + cpuHistoryStarted bool + cpuHistoryMaxSize int + cpuHistory []*z80.CPU + memBreakpoints [65536]byte +} + +func NewDebugger() *Debugger { + d := Debugger{ + stepMode: false, + doStep: false, + runMode: false, + runInst: 0, + breakpointsEnabled: false, + breakpoints: map[uint16]*breakpoint.Breakpoint{}, + cpuHistoryEnabled: false, + cpuHistoryStarted: false, + cpuHistoryMaxSize: 0, + cpuHistory: []*z80.CPU{}, + } + return &d +} + +func (d *Debugger) SetStepMode(step bool) { + d.SetRunMode(false) + d.stepMode = step +} + +func (d *Debugger) SetRunMode(run bool) { + if run { + d.runInst = 0 + } + d.runMode = run +} + +func (d *Debugger) RunMode() bool { + return d.runMode +} + +func (d *Debugger) DoStep() bool { + if d.doStep { + d.doStep = false + return true + } + return false +} + +func (d *Debugger) SetCpuHistoryEnabled(enable bool) { + d.cpuHistoryEnabled = enable +} + +func (d *Debugger) SetCpuHistoryMaxSize(size int) { + if size < 0 || size > 1_000_000 { + log.Error("CPU history max size must be positive and up to 1M") + } else { + d.cpuHistoryMaxSize = size + } +} + +func (d *Debugger) CpuHistoryClear() { + d.cpuHistory = make([]*z80.CPU, 0) +} + +func (d *Debugger) CpuHistorySize() int { + return len(d.cpuHistory) +} + +func (d *Debugger) CpuHistory(index int) *z80.CPU { + if index >= 0 && index < len(d.cpuHistory) { + return d.cpuHistory[index] + } + if len(d.cpuHistory) > 0 { + log.Warnf("CPU history index %d out of range [0:%d]", index, len(d.cpuHistory)-1) + } else { + log.Warn("CPU history is empty") + } + return nil +} + +func (d *Debugger) SetCpuHistoryStarted(started bool) { + d.cpuHistoryStarted = started +} + +func (d *Debugger) SaveHistory(state *z80.CPU) { + if d.cpuHistoryEnabled && d.cpuHistoryMaxSize > 0 && d.cpuHistoryStarted { + d.cpuHistory = append([]*z80.CPU{state}, d.cpuHistory...) + if len(d.cpuHistory) > d.cpuHistoryMaxSize { + d.cpuHistory = d.cpuHistory[0 : d.cpuHistoryMaxSize-1] + } + } +} + +func (d *Debugger) CheckBreakpoints(ctx map[string]interface{}) (bool, uint16) { + if d.breakpointsEnabled && d.runMode { + for n, bp := range d.breakpoints { + if bp != nil && bp.Hit(ctx) { + // breakpoint hit + if bp.Pass() >= bp.PassCount() { + bp.SetPass(0) + d.runMode = false + return true, n + } + // increment breakpoint pass count + bp.IncPass() + } + } + } + return false, 0 +} + +func (d *Debugger) SetBreakpointsEnabled(enabled bool) { + d.breakpointsEnabled = enabled +} + +func (d *Debugger) BreakpointsEnabled() bool { + return d.breakpointsEnabled +} + +// SetBreakpoint Create new breakpoint with specified number +func (d *Debugger) SetBreakpoint(number uint16, exp string) error { + var err error + bp, err := breakpoint.NewBreakpoint(exp) + if err == nil && bp != nil { + d.breakpoints[number] = bp + } + return err +} + +func (d *Debugger) SetBreakpointPassCount(number uint16, count uint16) { + bp, ok := d.breakpoints[number] + if ok && bp != nil { + bp.SetPass(0) + bp.SetPassCount(count) + } +} + +func (d *Debugger) SetBreakpointEnabled(number uint16, enabled bool) { + bp, ok := d.breakpoints[number] + if ok && bp != nil { + bp.SetEnabled(enabled) + } +} + +func (d *Debugger) BreakpointEnabled(number uint16) bool { + bp, ok := d.breakpoints[number] + if ok && bp != nil { + return bp.Enabled() + } + return false +} + +func (d *Debugger) ClearMemBreakpoints() { + for c := 0; c < 65536; c++ { + d.memBreakpoints[c] = 0 + } +} + +func (d *Debugger) StepMode() bool { + return d.stepMode +} + +func (d *Debugger) SetDoStep(on bool) { + d.doStep = on +} + +// BPExpression Return requested breakpoint +func (d *Debugger) BPExpression(number uint16) string { + bp, ok := d.breakpoints[number] + if ok && bp != nil { + return bp.Expression() + } + return "" +} + +// RunInst return and increment count of instructions executed +func (d *Debugger) RunInst() uint64 { + v := d.runInst + d.runInst++ + return v +} + +func (d *Debugger) SetMemBreakpoint(address uint16, typ byte, size uint16) { + var offset uint16 + for offset = address; offset < address+size; offset++ { + d.memBreakpoints[offset] = typ + } +} + +func (d *Debugger) CheckMemBreakpoints(accessMap *map[uint16]byte) (bool, uint16, byte) { + if !d.breakpointsEnabled { + return false, 0, 0 + } + for addr, typ := range *accessMap { + bp := d.memBreakpoints[addr] + if bp == 0 { + return false, addr, 0 + } + if (bp == 3) || bp == typ { + d.SetRunMode(false) + return true, addr, typ + } + } + return false, 0, 0 +} diff --git a/debug/evaluate.go b/debug/evaluate.go new file mode 100644 index 0000000..60efebc --- /dev/null +++ b/debug/evaluate.go @@ -0,0 +1,149 @@ +package debug + +import ( + "errors" + "strconv" + "strings" +) + +// Operators with their priority +var operators = map[string]int{ + "(": 12, ")": 12, + "*": 11, "/": 11, "%": 11, + "+": 10, "-": 10, + "<<": 9, ">>": 9, + "<": 8, "<=": 8, ">": 8, ">=": 8, + "=": 7, "!=": 7, + "&": 6, + "^": 5, + "|": 4, + "&&": 3, + "||": 2, +} + +var variables = map[string]bool{ + "A": true, "B": true, "C": true, "D": true, "E": true, "F": true, "H": true, "L": true, "I": true, + "R": true, "SF": true, "NF": true, "PF": true, "VF": true, "XF": true, "YF": true, "ZF": true, + "A'": true, "B'": true, "C'": true, "D'": true, "E'": true, "F'": true, "H'": true, "L'": true, + "AF": true, "BC": true, "DE": true, "HL": true, "IX": true, "IY": true, "PC": true, "SP": true, +} + +const ( + OTUnknown = iota + OTValue + OTVariable + OTOperation +) + +type Token struct { + name string + val uint16 + ot int +} + +type Expression struct { + infixExp string + inStack []Token + outStack []Token +} + +func NewExpression(infixExp string) *Expression { + return &Expression{infixExp, make([]Token, 0), make([]Token, 0)} +} + +func (e *Expression) Parse() error { + e.infixExp = strings.ToUpper(strings.TrimSpace(e.infixExp)) + if e.infixExp == "" { + return errors.New("no Expression") + } + ptr := 0 + for ptr < len(e.infixExp) { + token, err := getNextToken(e.infixExp[ptr:]) + if err != nil { + return err + } + err = validate(token) + if err != nil { + return err + } + err = e.parseToken(token) + if err != nil { + return err + } + ptr += len(token.name) + } + return nil +} + +func (e *Expression) parseToken(token Token) error { + return nil +} + +func validate(token Token) error { + switch token.ot { + case OTValue: + v, err := strconv.ParseUint(token.name, 0, 16) + if err != nil { + return err + } + token.val = uint16(v) + case OTVariable: + if !variables[token.name] { + return errors.New("unknown variable") + } + case OTOperation: + v, ok := operators[token.name] + if !ok { + return errors.New("unknown operation") + } + token.val = uint16(v) + default: + return errors.New("unknown token") + } + return nil +} + +const operations = "*/%+_<=>!&^|" + +func getNextToken(str string) (Token, error) { + ptr := 0 + exp := "" + ot := OTUnknown + for ptr < len(str) { + ch := str[ptr] + if ch == ' ' { + if ot == OTUnknown { + ptr++ + continue + } else { + // end of token + return Token{name: exp, ot: ot}, nil + } + } + + if (ch == 'X' || ch == 'O' || ch == 'B' || ch == 'H') && ot != OTValue { + exp += string(ch) + ptr++ + continue + } + + if ch >= '0' && ch <= '9' { + if len(exp) == 0 { + ot = OTValue + } + exp += string(ch) + ptr++ + continue + } + if strings.Contains(operations, string(ch)) { + if len(exp) == 0 { + ot = OTOperation + } + exp += string(ch) + ptr++ + continue + } + return Token{name: exp, ot: ot}, errors.New("invalid token") + } + return Token{name: exp, ot: ot}, nil +} diff --git a/debug/listener/constants.go b/debug/listener/constants.go new file mode 100644 index 0000000..92ac67c --- /dev/null +++ b/debug/listener/constants.go @@ -0,0 +1,14 @@ +package listener + +const welcomeMessage = "Welcome to Ocean-240.2 remote command protocol (ZRCP partial implementation)\nWrite help for available commands\n\ncommand> " +const emptyResponse = "\ncommand> " +const aboutResponse = "ZEsarUX remote command protocol" +const getVersionResponse = "12.1" +const getRegistersResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x F=%s F'=%s MEMPTR=%04x IM0 IFF%s VPS: 0 MMU=00000000000000000000000000000000" +const getStateResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x IM0 IFF%s (PC)=%s (SP)=%s MMU=00000000000000000000000000000000" + +const inCpuStepResponse = "\ncommand@cpu-step> " +const getMachineResponse = "64K RAM, no ZX\n" +const respErrorLoading = "ERROR loading file" +const quitResponse = "Sayonara baby\n" +const runUntilBPMessage = "Running until a breakpoint, key press or data sent, menu opening or other event\n" diff --git a/debuger/listener.go b/debug/listener/listener.go similarity index 54% rename from debuger/listener.go rename to debug/listener/listener.go index ea9dd0d..2b72970 100644 --- a/debuger/listener.go +++ b/debug/listener/listener.go @@ -1,4 +1,4 @@ -package debuger +package listener import ( "bufio" @@ -6,7 +6,11 @@ import ( "io" "net" "okemu/config" + "okemu/debug" + "okemu/debug/breakpoint" "okemu/okean240" + "okemu/z80" + "okemu/z80/dis" "os" "strings" //"okemu/logger" @@ -15,16 +19,6 @@ import ( log "github.com/sirupsen/logrus" ) -const welcomeMessage = "Welcome to ZEsarUX remote command protocol (ZRCP)\nWrite help for available commands\n\ncommand> " -const emptyResponse = "\ncommand> " -const aboutResponse = "ZEsarUX remote command protocol" -const getVersionResponse = "12.1" -const getRegistersResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x F=%s F'=%s MEMPTR=%04x IM0 IFF%s VPS: 0 MMU=00000000000000000000000000000000" -const inCpuStepResponse = "\ncommand@cpu-step> " -const getMachineResponse = "64K RAM, no ZX\n" -const respErrorLoading = "ERROR loading file" -const quitResponse = "Sayonara baby\n" - // Receive messages, split to strings and parse func handleConnection(c net.Conn) { reader := bufio.NewReader(c) @@ -40,6 +34,7 @@ func handleConnection(c net.Conn) { break } else { log.Errorf("TCP error: %v", err) + debugger.SetStepMode(false) return } } @@ -50,8 +45,8 @@ func handleConnection(c net.Conn) { } //byteBuffer.WriteByte(b) } + debugger.SetStepMode(false) activeWriter = nil - //log.Trace("TCP Connection closed") err := c.Close() if err != nil { log.Warnf("Can not close socket: %v", err) @@ -69,7 +64,7 @@ func writeWelcomeMessage(writer *bufio.Writer) bool { func writeResponseMessage(writer *bufio.Writer, message string) bool { prompt := emptyResponse - if computer.IsStepMode() { + if debugger.StepMode() { prompt = inCpuStepResponse } @@ -86,13 +81,33 @@ func writeResponseMessage(writer *bufio.Writer, message string) bool { return true } +func writeMessage(writer *bufio.Writer, message string) bool { + _, err := writer.WriteString(message) + if err != nil { + log.Errorf("TCP error: %v", err) + return false + } + err = writer.Flush() + if err != nil { + log.Errorf("TCP error: %v", err) + return false + } + return true +} + +// var +var debugger *debug.Debugger +var disassembler *dis.Disassembler var computer *okean240.ComputerType // SetupTcpHandler Setup TCP listener, handle connections -func SetupTcpHandler(config *config.OkEmuConfig, comp *okean240.ComputerType) { - port := config.Host + ":" + strconv.Itoa(config.Port) +func SetupTcpHandler(config *config.OkEmuConfig, debug *debug.Debugger, disasm *dis.Disassembler, comp *okean240.ComputerType) { + port := config.Debugger.Host + ":" + strconv.Itoa(config.Debugger.Port) + debugger = debug + disassembler = disasm computer = comp - log.Infof("Serve TCP connections on %s", port) + + log.Infof("Ready for debugger connections on %s", port) l, err := net.Listen("tcp4", port) if err != nil { @@ -138,20 +153,19 @@ func HandleCommand(str string, writer *bufio.Writer) bool { switch cmd { case "cpu-step": - computer.Do() - writeResponseMessage(writer, " "+fmt.Sprintf("%04X", computer.GetCPUState().PC)) + debugger.SetDoStep(true) // computer.Do() + text := disassembler.Disassm(computer.GetCPUState().PC) + writeResponseMessage(writer, registersResponse(computer.GetCPUState())+" TSTATES: "+strconv.Itoa(int(computer.TStatesPartial()))+"\n"+text) case "run": - _, e := writer.WriteString("Running until a breakpoint, key press or data sent, menu opening or other event\n") - if e != nil { - log.Warnf("Error writing to buffer: %v", e) - } - e = writer.Flush() - if e != nil { - log.Warnf("Error flushing the buffer: %v", e) - } - computer.SetRunMode(true) + writeMessage(writer, runUntilBPMessage) + debugger.SetRunMode(true) + case "disassemble": + writeResponseMessage(writer, disassemble(params)) case "get-tstates-partial": - writeResponseMessage(writer, strconv.FormatUint(computer.Cycles(), 10)) + writeResponseMessage(writer, strconv.FormatUint(computer.TStatesPartial(), 10)) + case "reset-tstates-partial": + computer.ResetTStatesPartial() + writeResponseMessage(writer, "") case "close-all-menus": writeResponseMessage(writer, "") case "about": @@ -159,17 +173,17 @@ func HandleCommand(str string, writer *bufio.Writer) bool { case "get-version": writeResponseMessage(writer, getVersionResponse) case "get-registers": - writeResponseMessage(writer, registersResponse()) + writeResponseMessage(writer, registersResponse(computer.GetCPUState())) case "set-register": writeResponseMessage(writer, setRegister(params)) case "hard-reset-cpu": computer.Reset() writeResponseMessage(writer, "") case "enter-cpu-step": - computer.SetStepMode(true) + debugger.SetStepMode(true) writeResponseMessage(writer, "") case "exit-cpu-step": - computer.SetStepMode(false) + debugger.SetStepMode(false) writeResponseMessage(writer, "") case "set-debug-settings": log.Debugf("Set debug settings to %s", params) @@ -177,13 +191,15 @@ func HandleCommand(str string, writer *bufio.Writer) bool { case "get-current-machine": writeResponseMessage(writer, getMachineResponse) case "clear-membreakpoints": - computer.ClearMemBreakpoints() + debugger.ClearMemBreakpoints() writeResponseMessage(writer, "") + case "set-membreakpoint": // addr type size + writeResponseMessage(writer, SetMemBreakpoint(params)) case "enable-breakpoints": - computer.SetBreakpointsEnabled(true) + debugger.SetBreakpointsEnabled(true) writeResponseMessage(writer, "") case "disable-breakpoints": - computer.SetBreakpointsEnabled(false) + debugger.SetBreakpointsEnabled(false) writeResponseMessage(writer, "") case "enable-breakpoint": writeResponseMessage(writer, setBreakpointState(params, true)) @@ -194,6 +210,9 @@ func HandleCommand(str string, writer *bufio.Writer) bool { case "set-breakpoint": // 1 PC=0010Bh writeResponseMessage(writer, setBreakpoint(params)) + case "set-breakpointpasscount": + setBreakpointPassCount(params) + writeResponseMessage(writer, "") case "cpu-code-coverage": //"enabled no" writeResponseMessage(writer, "") @@ -204,7 +223,8 @@ func HandleCommand(str string, writer *bufio.Writer) bool { // "started yes" // "ignrephalt yes" // "ignrepldxr yes" - writeResponseMessage(writer, "") + + writeResponseMessage(writer, doCpuHistory(params)) case "extended-stack": // "enabled no" // "enabled yes" @@ -219,6 +239,11 @@ func HandleCommand(str string, writer *bufio.Writer) bool { writeResponseMessage(writer, readMemory(params)) case "quit": quit = true + case "snapshot-save": + writeResponseMessage(writer, snapshotSave(params)) + case "set-breakpointaction": + // now do nothing + writeResponseMessage(writer, "") default: log.Debugf("Unhandled Command: %s", str) writeResponseMessage(writer, "") @@ -226,6 +251,96 @@ func HandleCommand(str string, writer *bufio.Writer) bool { return !quit } +func convertToUint16(s string) (uint16, error) { + v := strings.TrimSpace(strings.ToUpper(s)) + base := 0 + if strings.HasSuffix(v, "h") || strings.HasSuffix(v, "H") { + v = strings.TrimSuffix(v, "H") + v = strings.TrimSuffix(v, "h") + base = 16 + } + a, e := strconv.ParseUint(v, base, 16) + return uint16(a), e +} + +func SetMemBreakpoint(param string) string { + param = strings.TrimSpace(param) + params := strings.Split(param, " ") + if len(params) < 1 { + return "error, not enough parameters" + } + address, err := convertToUint16(params[0]) + if err != nil { + return "error, illegal address: '" + params[0] + "'" + } + t := uint16(3) + // if has type + if len(params) > 1 { + t, err = convertToUint16(params[1]) + if err != nil || t > 3 { + return "error, illegal access type: '" + params[1] + "'" + } + } + + s := uint16(1) + if len(params) > 2 { + s, err = convertToUint16(params[2]) + if err != nil { + return "error, illegal memory size: '" + params[2] + "'" + } + } + if debugger != nil { + debugger.SetMemBreakpoint(address, byte(t), s) + } + return "" +} + +func doCpuHistory(param string) string { + param = strings.TrimSpace(param) + params := strings.Split(param, " ") + if len(params) == 0 { + return "error" + } + cmd := params[0] + switch cmd { + case "enabled": + if len(params) != 2 { + return "error" + } + debugger.SetCpuHistoryEnabled(params[1] == "yes") + case "clear": + debugger.CpuHistoryClear() + case "started": + if len(params) != 2 { + return "error" + } + debugger.SetCpuHistoryStarted(params[1] == "yes") + case "set-max-size": + if len(params) != 2 { + return "error" + } + size, err := strconv.Atoi(params[1]) + if err != nil { + return "error" + } + debugger.SetCpuHistoryMaxSize(size) + case "get": + if len(params) != 2 { + return "error" + } + index, err := strconv.Atoi(params[1]) + if err != nil { + return "error" + } + history := debugger.CpuHistory(index) + if history != nil { + return stateResponse(history) + } + return "ERROR: index out of range" + } + return "" +} + func loadBinary(param string) string { params := strings.Split(param, " ") if len(params) < 2 { @@ -246,9 +361,11 @@ func loadBinary(param string) string { return respErrorLoading } if len(params) > 2 { - length, e = strconv.Atoi(params[1]) + l, e := strconv.ParseInt(params[2], 0, 32) if e != nil { length = 0 + } else { + length = int(l) } } data, err := os.ReadFile(fn) @@ -256,9 +373,10 @@ func loadBinary(param string) string { log.Errorf("Error reading file: %v", err) return respErrorLoading } - if length != 0 && len(data) < length { - log.Errorf("File too short. Expected %d bytes, got %d", len(data), length) - return respErrorLoading + if length != 0 && len(data) != length { + log.Warnf("File size does not match the specified length. Expected %d bytes, got %d.", length, len(data)) + //return respErrorLoading + length = len(data) } if length == 0 { length = len(data) @@ -289,8 +407,8 @@ func iifStr(iif1, iif2 bool) string { // registersResponse Build string // PC=%4x SP=%4x AF=%4x BC=%4x HL=%4x DE=%4x IX=%4x IY=%4x AF'=%4x BC'=%4x HL'=%4x DE'=%4x I=%2x // R=%2x F=%s F'=%s MEMPTR=%4x IM0 IFF-- VPS: 0 MMU=00000000000000000000000000000000 -func registersResponse() string { - state := computer.GetCPUState() +func registersResponse(state *z80.CPU) string { + //state := computer.GetCPUState() resp := fmt.Sprintf(getRegistersResponse, state.PC, state.SP, @@ -311,7 +429,43 @@ func registersResponse() string { state.MemPtr, iifStr(state.Iff1, state.Iff2), ) - log.Debug(resp) + log.Trace(resp) + return resp +} + +func getNBytes(addr uint16, n uint16) string { + res := "" + for i := uint16(0); i < n; i++ { + b := computer.MemRead(addr + i) + res += fmt.Sprintf("%02X", b) + } + return res +} + +// stateResponse build string, represent history state +// PC=003a SP=ff46 AF=005c BC=174b HL=107f DE=0006 IX=ffff IY=5c3a AF'=0044 BC'=ffff HL'=ffff DE'=5cb9 I=3f R=78 +// IM0 IFF-- (PC)=2a785c23 (SP)=107f MMU=00000000000000000000000000000000 +func stateResponse(state *z80.CPU) string { + resp := fmt.Sprintf(getStateResponse, + state.PC, + state.SP, + toW(state.A, state.Flags.GetFlags()), + toW(state.B, state.C), + toW(state.H, state.L), + toW(state.D, state.E), + state.IX, + state.IY, + toW(state.AAlt, state.FlagsAlt.GetFlags()), + toW(state.BAlt, state.CAlt), + toW(state.HAlt, state.LAlt), + toW(state.DAlt, state.EAlt), + state.I, + state.R, + iifStr(state.Iff1, state.Iff2), + getNBytes(state.PC, 4), + getNBytes(state.SP, 2), + ) + log.Trace(resp) return resp } @@ -320,12 +474,12 @@ func setRegister(param string) string { params := strings.Split(param, "=") if len(params) != 2 { log.Errorf("Invalid set register parameter: %s", param) - return registersResponse() + return "error" } val, e := strconv.Atoi(params[1]) if e != nil { log.Errorf("Invalid set register parameter value: %s", params[1]) - return registersResponse() + return "error" } switch params[0] { case "SP": @@ -359,14 +513,14 @@ func setRegister(param string) string { log.Errorf("Unsupported set register parameter: %s", param) } computer.SetCPUState(state) - return registersResponse() + return registersResponse(computer.GetCPUState()) } func readMemory(param string) string { params := strings.Split(param, " ") if len(params) != 2 { log.Errorf("Invalid read memory parameter: %s", param) - return registersResponse() + return "error" //registersResponse(computer.GetCPUState()) } offset, e := strconv.Atoi(params[0]) if e != nil { @@ -381,7 +535,6 @@ func readMemory(param string) string { for i := 0; i < size; i++ { resp += fmt.Sprintf("%02X", computer.MemRead(uint16(offset)+uint16(i))) } - log.Tracef("ReadMemory[%d,%d]:\n%s", offset, size, resp) return resp } @@ -411,7 +564,7 @@ func getExtendedStack(param string) string { for i := sp; i > spEnd; i -= 2 { resp += fmt.Sprintf("%04XH default\n", computer.MemRead(i)) } - log.Debugf("Stack[%d,%d]:\n%s", sp, size, resp) + //log.Debugf("Stack[%d,%d]:\n%s", sp, size, resp) return resp } @@ -421,49 +574,94 @@ func setBreakpointState(param string, enable bool) string { log.Errorf("Invalid breakpoint parameter: %s", param) return "" } - if enable && !computer.IsBreakpointsEnabled() { + if enable && !debugger.BreakpointsEnabled() { return "Error. You must enable breakpoints first" } - computer.SetBreakpointEnabled(uint16(no), enable) + debugger.SetBreakpointEnabled(uint16(no), enable) return "" } func setBreakpoint(param string) string { // 1 PC=0010Bh params := strings.Split(param, " ") - if len(params) != 2 { + if len(params) < 2 { log.Errorf("Invalid set breakpoint parameters: %s", param) - return "" + return "Error, invalid parameters" } - no, e := strconv.Atoi(params[0]) - if e != nil || no > okean240.MaxBreakpoints || no < 1 { + no, e := strconv.ParseUint(params[0], 0, 16) + if e != nil || no > breakpoint.MaxBreakpoints || no < 1 { log.Errorf("Invalid breakpoint number: %s", params[0]) - return "" + return "Error, invalid breakpoint number" } - regv := strings.Split(params[1], "=") - if len(regv) != 2 { - log.Errorf("Invalid breakpoint parameter: %s", params[1]) - return "" - } - addr, e := strconv.ParseUint(strings.TrimSuffix(regv[1], "h"), 16, 32) - if e != nil || addr < 0 || addr >= 65535 { - log.Errorf("Invalid breakpoint address: %s", regv[1]) - return "" - } - if regv[0] == "PC" { - computer.SetBreakpoint(uint16(no), uint16(addr)) - } else { - log.Errorf("Unsupported BP: %s", params[1]) + e = debugger.SetBreakpoint(uint16(no), param[len(params[0]):]) + if e != nil { + return "Error: " + e.Error() } return "" } -func BreakpointHit(no uint16) { +func typToString(typ uint8) string { + switch typ { + case 0: + return "D" + case 1: + return "R" + case 2: + return "W" + case 3: + return "R/W" + default: + return "x" + } +} + +func BreakpointHit(number uint16, typ byte) { if activeWriter != nil { pc := computer.GetCPUState().PC - rep := fmt.Sprintf("Breakpoint fired: PC=%XH\n %04X NOP", pc, pc) + res := disassembler.Disassm(pc) + msg := "" + if typ == 0 { + msg = debugger.BPExpression(number) + } else { + msg = fmt.Sprintf("MEM[%04X] %s", number, typToString(typ)) + } + rep := fmt.Sprintf("Breakpoint fired: %s\n%s", msg, res) log.Debug(rep) writeResponseMessage(activeWriter, rep) } } + +func setBreakpointPassCount(param string) { + params := strings.Split(param, " ") + if len(params) != 2 { + log.Errorf("Set breakpoint passCount failed, expected 2 params, got %d", len(params)) + } + bpNo, err := strconv.Atoi(params[0]) + if err != nil || bpNo < 0 || bpNo > breakpoint.MaxBreakpoints { + log.Errorf("Invalid BP no.: %v", err) + } + passCount, err := strconv.Atoi(params[1]) + if err != nil || passCount < 0 || passCount > 65535 { + log.Errorf("Invalid BP passCount: %v", err) + } + debugger.SetBreakpointPassCount(uint16(bpNo), uint16(passCount)) +} + +func disassemble(param string) string { + addr, e := strconv.ParseUint(param, 0, 16) + if e != nil { + log.Errorf("Invalid disassemble address: %s", param) + } + res := disassembler.Disassm(uint16(addr)) + log.Debug(res) + return res +} + +func snapshotSave(params string) string { + e := computer.SaveSnapshot(strings.TrimSpace(params)) + if e != nil { + return fmt.Sprintf("Error saving snapshot: %s", e) + } + return "" +} diff --git a/debug/listener/listener_test.go b/debug/listener/listener_test.go new file mode 100644 index 0000000..bbf3941 --- /dev/null +++ b/debug/listener/listener_test.go @@ -0,0 +1,17 @@ +package listener + +import ( + "testing" +) + +const exp1 = "CS=0x0100 & SP>=256" +const memBpSet = "11ch 3 1" + +func Test_SetMemPB(t *testing.T) { + + resp := SetMemBreakpoint(memBpSet) + if resp != "" { + t.Errorf("SetMemBreakpoint() returned %s", resp) + } + +} diff --git a/go.mod b/go.mod index fa18e51..b889e88 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( fyne.io/fyne/v2 v2.7.3 + github.com/PaesslerAG/gval v1.2.4 github.com/howeyc/crc16 v0.0.0-20171223171357-2b2a61e366a6 github.com/sirupsen/logrus v1.9.4 gopkg.in/yaml.v3 v3.0.1 @@ -33,6 +34,7 @@ require ( github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rymdport/portal v0.4.2 // indirect + github.com/shopspring/decimal v1.3.1 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/testify v1.11.1 // indirect diff --git a/go.sum b/go.sum index 4978028..286537a 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,10 @@ fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM= fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/PaesslerAG/gval v1.2.4 h1:rhX7MpjJlcxYwL2eTTYIOBUyEKZ+A96T9vQySWkVUiU= +github.com/PaesslerAG/gval v1.2.4/go.mod h1:XRFLwvmkTEdYziLdaCeCa5ImcGVrfQbeNUbVR+C6xac= +github.com/PaesslerAG/jsonpath v0.1.0 h1:gADYeifvlqK3R3i2cR5B4DGgxLXIPb3TRTH1mGi0jPI= +github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -59,6 +63,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= diff --git a/gval/LICENSE b/gval/LICENSE new file mode 100644 index 0000000..0716dbc --- /dev/null +++ b/gval/LICENSE @@ -0,0 +1,12 @@ +Copyright (c) 2017, Paessler AG +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/gval/benchmarks_test.go b/gval/benchmarks_test.go new file mode 100644 index 0000000..dcb35bd --- /dev/null +++ b/gval/benchmarks_test.go @@ -0,0 +1,150 @@ +package gval + +import ( + "context" + "testing" +) + +func BenchmarkGval(bench *testing.B) { + benchmarks := []evaluationTest{ + { + // Serves as a "water test" to give an idea of the general overhead + name: "const", + expression: "1", + }, + { + name: "single parameter", + expression: "requests_made", + parameter: map[string]interface{}{ + "requests_made": 99.0, + }, + }, + { + name: "parameter", + expression: "requests_made > requests_succeeded", + parameter: map[string]interface{}{ + "requests_made": 99.0, + "requests_succeeded": 90.0, + }, + }, + { + // The most common use case, a single variable, modified slightly, compared to a constant. + // This is the "expected" use case. + name: "common", + expression: "(requests_made * requests_succeeded / 100) >= 90", + parameter: map[string]interface{}{ + "requests_made": 99.0, + "requests_succeeded": 90.0, + }, + }, + { + // All major possibilities in one expression. + name: "complex", + expression: `2 > 1 && + "something" != "nothing" || + date("2014-01-20") < date("Wed Jul 8 23:07:35 MDT 2015") && + object["Variable name with spaces"] <= array[0] && + modifierTest + 1000 / 2 > (80 * 100 % 2)`, + parameter: map[string]interface{}{ + "object": map[string]interface{}{"Variable name with spaces": 10.}, + "array": []interface{}{0.}, + "modifierTest": 7.3, + }, + }, + { + // no variables, no modifiers + name: "literal", + expression: "(2) > (1)", + }, + { + name: "modifier", + expression: "(2) + (2) == (4)", + }, + { + // Benchmarks uncompiled parameter regex operators, which are the most expensive of the lot. + // Note that regex compilation times are unpredictable and wily things. The regex engine has a lot of edge cases + // and possible performance pitfalls. This test doesn't aim to be comprehensive against all possible regex scenarios, + // it is primarily concerned with tracking how much longer it takes to compile a regex at evaluation-time than during parse-time. + name: "regex", + expression: "(foo !~ bar) && (foo + bar =~ oba)", + parameter: map[string]interface{}{ + "foo": "foo", + "bar": "bar", + "baz": "baz", + "oba": ".*oba.*", + }, + }, + { + // Benchmarks pre-compilable regex patterns. Meant to serve as a sanity check that constant strings used as regex patterns + // are actually being precompiled. + // Also demonstrates that (generally) compiling a regex at evaluation-time takes an order of magnitude more time than pre-compiling. + name: "constant regex", + expression: `(foo !~ "[bB]az") && (bar =~ "[bB]ar")`, + parameter: map[string]interface{}{ + "foo": "foo", + "bar": "bar", + "baz": "baz", + "oba": ".*oba.*", + }, + }, + { + name: "accessors", + expression: "foo.Int", + parameter: fooFailureParameters, + }, + { + name: "accessors method", + expression: "foo.Func()", + parameter: fooFailureParameters, + }, + { + name: "accessors method parameter", + expression: `foo.FuncArgStr("bonk")`, + parameter: fooFailureParameters, + }, + { + name: "nested accessors", + expression: `foo.Nested.Funk`, + parameter: fooFailureParameters, + }, + { + name: "decimal arithmetic", + expression: "(requests_made * requests_succeeded / 100)", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "requests_made": 99.0, + "requests_succeeded": 90.0, + }, + }, + { + name: "decimal logic", + expression: "(requests_made * requests_succeeded / 100) >= 90", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "requests_made": 99.0, + "requests_succeeded": 90.0, + }, + }, + } + for _, benchmark := range benchmarks { + eval, err := Full().NewEvaluable(benchmark.expression) + if err != nil { + bench.Fatal(err) + } + _, err = eval(context.Background(), benchmark.parameter) + if err != nil { + bench.Fatal(err) + } + bench.Run(benchmark.name+"_evaluation", func(bench *testing.B) { + for i := 0; i < bench.N; i++ { + eval(context.Background(), benchmark.parameter) + } + }) + bench.Run(benchmark.name+"_parsing", func(bench *testing.B) { + for i := 0; i < bench.N; i++ { + Full().NewEvaluable(benchmark.expression) + } + }) + + } +} diff --git a/gval/evaluable.go b/gval/evaluable.go new file mode 100644 index 0000000..efd0642 --- /dev/null +++ b/gval/evaluable.go @@ -0,0 +1,366 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strconv" +) + +// Selector allows for custom variable selection from structs +// +// Return value is again handled with variable() until end of the given path +type Selector interface { + SelectGVal(c context.Context, key string) (interface{}, error) +} + +// Evaluable evaluates given parameter +type Evaluable func(c context.Context, parameter interface{}) (interface{}, error) + +// EvalInt evaluates given parameter to an int +func (e Evaluable) EvalInt(c context.Context, parameter interface{}) (int, error) { + v, err := e(c, parameter) + if err != nil { + return 0, err + } + + f, ok := convertToUint(v) + if !ok { + return 0, fmt.Errorf("expected number but got %v (%T)", v, v) + } + return int(f), nil +} + +// EvalUint evaluates given parameter to a float64 +func (e Evaluable) EvalUint(c context.Context, parameter interface{}) (uint, error) { + v, err := e(c, parameter) + if err != nil { + return 0, err + } + + f, ok := convertToUint(v) + if !ok { + return 0, fmt.Errorf("expected number but got %v (%T)", v, v) + } + return f, nil +} + +// EvalBool evaluates given parameter to a bool +func (e Evaluable) EvalBool(c context.Context, parameter interface{}) (bool, error) { + v, err := e(c, parameter) + if err != nil { + return false, err + } + + b, ok := convertToBool(v) + if !ok { + return false, fmt.Errorf("expected bool but got %v (%T)", v, v) + } + return b, nil +} + +// EvalString evaluates given parameter to a string +func (e Evaluable) EvalString(c context.Context, parameter interface{}) (string, error) { + o, err := e(c, parameter) + if err != nil { + return "", err + } + if s, ok := o.(string); ok { + return s, nil + } + return fmt.Sprintf("%v", o), nil +} + +// Const Evaluable represents given constant +func (*Parser) Const(value interface{}) Evaluable { + return constant(value) +} + +//go:noinline +func constant(value interface{}) Evaluable { + return func(c context.Context, v interface{}) (interface{}, error) { + return value, nil + } +} + +// Var Evaluable represents value at given path. +// It supports with default language VariableSelector: +// +// map[interface{}]interface{}, +// map[string]interface{} and +// []interface{} and via reflect +// struct fields, +// struct methods, +// slices and +// map with int or string key. +func (p *Parser) Var(path ...Evaluable) Evaluable { + if p.selector == nil { + return variable(path) + } + return p.selector(path) +} + +// Evaluables is a slice of Evaluable. +type Evaluables []Evaluable + +// EvalStrings evaluates given parameter to a string slice +func (evs Evaluables) EvalStrings(c context.Context, parameter interface{}) ([]string, error) { + strs := make([]string, len(evs)) + for i, p := range evs { + k, err := p.EvalString(c, parameter) + if err != nil { + return nil, err + } + strs[i] = k + } + return strs, nil +} + +func variable(path Evaluables) Evaluable { + return func(c context.Context, v interface{}) (interface{}, error) { + v2 := v + for _, p := range path { + k, err := p.EvalString(c, v) + if err != nil { + return nil, err + } + switch o := v2.(type) { + case Selector: + v2, err = o.SelectGVal(c, k) + if err != nil { + return nil, fmt.Errorf("failed to select '%s' on %T: %w", k, o, err) + } + continue + case map[interface{}]interface{}: + v2 = o[k] + continue + case map[string]interface{}: + v2 = o[k] + continue + case []interface{}: + if i, err := strconv.Atoi(k); err == nil && i >= 0 && len(o) > i { + v2 = o[i] + continue + } + default: + var ok bool + v2, ok = reflectSelect(k, o) + if !ok { + return nil, fmt.Errorf("unknown parameter '%s' on %T", k, o) + } + } + } + return v2, nil + } +} + +func reflectSelect(key string, value interface{}) (selection interface{}, ok bool) { + vv := reflect.ValueOf(value) + vvElem := resolvePotentialPointer(vv) + + switch vvElem.Kind() { + case reflect.Map: + mapKey, ok := reflectConvertTo(vv.Type().Key().Kind(), key) + if !ok { + return nil, false + } + + vvElem = vv.MapIndex(reflect.ValueOf(mapKey)) + vvElem = resolvePotentialPointer(vvElem) + + if vvElem.IsValid() { + return vvElem.Interface(), true + } + + // key didn't exist. Check if there is a bound method + method := vv.MethodByName(key) + if method.IsValid() { + return method.Interface(), true + } + + case reflect.Slice: + if i, err := strconv.Atoi(key); err == nil && i >= 0 && vv.Len() > i { + vvElem = resolvePotentialPointer(vv.Index(i)) + return vvElem.Interface(), true + } + + // key not an int. Check if there is a bound method + method := vv.MethodByName(key) + if method.IsValid() { + return method.Interface(), true + } + + case reflect.Struct: + field := vvElem.FieldByName(key) + if field.IsValid() { + return field.Interface(), true + } + + method := vv.MethodByName(key) + if method.IsValid() { + return method.Interface(), true + } + } + return nil, false +} + +func resolvePotentialPointer(value reflect.Value) reflect.Value { + if value.Kind() == reflect.Ptr { + return value.Elem() + } + return value +} + +func reflectConvertTo(k reflect.Kind, value string) (interface{}, bool) { + switch k { + case reflect.String: + return value, true + case reflect.Int: + if i, err := strconv.Atoi(value); err == nil { + return i, true + } + } + return nil, false +} + +func (*Parser) callFunc(fun function, args ...Evaluable) Evaluable { + return func(c context.Context, v interface{}) (ret interface{}, err error) { + a := make([]interface{}, len(args)) + for i, arg := range args { + ai, err := arg(c, v) + if err != nil { + return nil, err + } + a[i] = ai + } + return fun(c, a...) + } +} + +func (*Parser) callEvaluable(fullname string, fun Evaluable, args ...Evaluable) Evaluable { + return func(c context.Context, v interface{}) (ret interface{}, err error) { + f, err := fun(c, v) + + if err != nil { + return nil, fmt.Errorf("could not call function: %w", err) + } + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("failed to execute function '%s': %s", fullname, r) + ret = nil + } + }() + + ff := reflect.ValueOf(f) + + if ff.Kind() != reflect.Func { + return nil, fmt.Errorf("could not call '%s' type %T", fullname, f) + } + + a := make([]reflect.Value, len(args)) + for i := range args { + arg, err := args[i](c, v) + if err != nil { + return nil, err + } + a[i] = reflect.ValueOf(arg) + } + + rr := ff.Call(a) + + r := make([]interface{}, len(rr)) + for i, e := range rr { + r[i] = e.Interface() + } + + errorInterface := reflect.TypeOf((*error)(nil)).Elem() + if len(r) > 0 && ff.Type().Out(len(r)-1).Implements(errorInterface) { + if r[len(r)-1] != nil { + err = r[len(r)-1].(error) + } + r = r[0 : len(r)-1] + } + + switch len(r) { + case 0: + return err, nil + case 1: + return r[0], err + default: + return r, err + } + } +} + +// IsConst returns if the Evaluable is a Parser.Const() value +func (e Evaluable) IsConst() bool { + pc := reflect.ValueOf(constant(nil)).Pointer() + pe := reflect.ValueOf(e).Pointer() + return pc == pe +} + +func regEx(a, b Evaluable) (Evaluable, error) { + if !b.IsConst() { + return func(c context.Context, o interface{}) (interface{}, error) { + a, err := a.EvalString(c, o) + if err != nil { + return nil, err + } + b, err := b.EvalString(c, o) + if err != nil { + return nil, err + } + matched, err := regexp.MatchString(b, a) + return matched, err + }, nil + } + s, err := b.EvalString(context.TODO(), nil) + if err != nil { + return nil, err + } + regex, err := regexp.Compile(s) + if err != nil { + return nil, err + } + return func(c context.Context, v interface{}) (interface{}, error) { + s, err := a.EvalString(c, v) + if err != nil { + return nil, err + } + return regex.MatchString(s), nil + }, nil +} + +func notRegEx(a, b Evaluable) (Evaluable, error) { + if !b.IsConst() { + return func(c context.Context, o interface{}) (interface{}, error) { + a, err := a.EvalString(c, o) + if err != nil { + return nil, err + } + b, err := b.EvalString(c, o) + if err != nil { + return nil, err + } + matched, err := regexp.MatchString(b, a) + return !matched, err + }, nil + } + s, err := b.EvalString(context.TODO(), nil) + if err != nil { + return nil, err + } + regex, err := regexp.Compile(s) + if err != nil { + return nil, err + } + return func(c context.Context, v interface{}) (interface{}, error) { + s, err := a.EvalString(c, v) + if err != nil { + return nil, err + } + return !regex.MatchString(s), nil + }, nil +} diff --git a/gval/evaluable_test.go b/gval/evaluable_test.go new file mode 100644 index 0000000..bd00e4d --- /dev/null +++ b/gval/evaluable_test.go @@ -0,0 +1,250 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" +) + +func TestEvaluable_IsConst(t *testing.T) { + p := Parser{} + tests := []struct { + name string + e Evaluable + want bool + }{ + { + "const", + p.Const(80.5), + true, + }, + { + "var", + p.Var(), + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.IsConst(); got != tt.want { + t.Errorf("Evaluable.IsConst() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEvaluable_EvalInt(t *testing.T) { + tests := []struct { + name string + e Evaluable + want int + wantErr bool + }{ + { + "point", + constant("5.3"), + 5, + false, + }, + { + "number", + constant(255.), + 255, + false, + }, + { + "error", + constant("5.3 cm"), + 0, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.e.EvalInt(context.Background(), nil) + if (err != nil) != tt.wantErr { + t.Errorf("Evaluable.EvalInt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Evaluable.EvalInt() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEvaluable_EvalFloat64(t *testing.T) { + tests := []struct { + name string + e Evaluable + want float64 + wantErr bool + }{ + { + "point", + constant("5.3"), + 5.3, + false, + }, + { + "number", + constant(255.), + 255, + false, + }, + { + "error", + constant("5.3 cm"), + 0, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.e.EvalUint(context.Background(), nil) + if (err != nil) != tt.wantErr { + t.Errorf("Evaluable.EvalUint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Evaluable.EvalUint() = %v, want %v", got, tt.want) + } + }) + } +} + +type testSelector struct { + str string + Map map[string]interface{} +} + +func (s testSelector) SelectGVal(ctx context.Context, k string) (interface{}, error) { + if k == "str" { + return s.str, nil + } + + if k == "map" { + return s.Map, nil + } + + if strings.HasPrefix(k, "deep") { + return s, nil + } + + return nil, fmt.Errorf("unknown-key") + +} + +func TestEvaluable_CustomSelector(t *testing.T) { + var ( + lang = Base() + tests = []struct { + name string + expr string + params interface{} + want interface{} + wantErr bool + }{ + { + "unknown", + "s.Foo", + map[string]interface{}{"s": &testSelector{}}, + nil, + true, + }, + { + "field directly", + "s.Str", + map[string]interface{}{"s": &testSelector{str: "test-value"}}, + nil, + true, + }, + { + "field via selector", + "s.str", + map[string]interface{}{"s": &testSelector{str: "test-value"}}, + "test-value", + false, + }, + { + "flat", + "str", + &testSelector{str: "test-value"}, + "test-value", + false, + }, + { + "map field", + "s.map.foo", + map[string]interface{}{"s": &testSelector{Map: map[string]interface{}{"foo": "bar"}}}, + "bar", + false, + }, + { + "crawl to val", + "deep.deeper.deepest.str", + &testSelector{str: "foo"}, + "foo", + false, + }, + { + "crawl to struct", + "deep.deeper.deepest", + &testSelector{}, + testSelector{}, + false, + }, + } + + booltests = []struct { + name string + expr string + params interface{} + want interface{} + wantErr bool + }{ + { + "test method", + "s.IsZero", + map[string]interface{}{"s": time.Now()}, + false, + false, + }, + } + ) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := lang.Evaluate(tt.expr, tt.params) + if (err != nil) != tt.wantErr { + t.Errorf("Evaluable.Evaluate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Evaluable.Evaluate() = %v, want %v", got, tt.want) + } + }) + } + + for _, tt := range booltests { + t.Run(tt.name, func(t *testing.T) { + got, err := lang.Evaluate(tt.expr, tt.params) + if (err != nil) != tt.wantErr { + t.Errorf("Evaluable.Evaluate() error = %v, wantErr %v", err, tt.wantErr) + return + } + got, ok := convertToBool(got) + if !ok { + t.Errorf("Evaluable.Evaluate() error = nok, wantErr %v", tt.wantErr) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Evaluable.Evaluate() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/gval/example_test.go b/gval/example_test.go new file mode 100644 index 0000000..1045c93 --- /dev/null +++ b/gval/example_test.go @@ -0,0 +1,450 @@ +package gval_test + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/PaesslerAG/gval" + "github.com/PaesslerAG/jsonpath" +) + +func Example() { + + vars := map[string]interface{}{"name": "World"} + + value, err := gval.Evaluate(`"Hello " + name + "!"`, vars) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // Hello World! +} + +func ExampleEvaluate() { + + value, err := gval.Evaluate("foo > 0", map[string]interface{}{ + "foo": -1., + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // false +} + +func ExampleEvaluate_nestedParameter() { + + value, err := gval.Evaluate("foo.bar > 0", map[string]interface{}{ + "foo": map[string]interface{}{"bar": -1.}, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // false +} + +func ExampleEvaluate_array() { + + value, err := gval.Evaluate("foo[0]", map[string]interface{}{ + "foo": []interface{}{-1.}, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // -1 +} + +func ExampleEvaluate_complexAccessor() { + + value, err := gval.Evaluate(`foo["b" + "a" + "r"]`, map[string]interface{}{ + "foo": map[string]interface{}{"bar": -1.}, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // -1 +} + +func ExampleEvaluate_arithmetic() { + + value, err := gval.Evaluate("(requests_made * requests_succeeded / 100) >= 90", + map[string]interface{}{ + "requests_made": 100, + "requests_succeeded": 80, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // false +} + +func ExampleEvaluate_string() { + + value, err := gval.Evaluate(`http_response_body == "service is ok"`, + map[string]interface{}{ + "http_response_body": "service is ok", + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // true +} + +func ExampleEvaluate_float64() { + + value, err := gval.Evaluate("(mem_used / total_mem) * 100", + map[string]interface{}{ + "total_mem": 1024, + "mem_used": 512, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // 50 +} + +func ExampleEvaluate_dateComparison() { + + value, err := gval.Evaluate("date(`2014-01-02`) > date(`2014-01-01 23:59:59`)", + nil, + // define Date comparison because it is not part expression language gval + gval.InfixOperator(">", func(a, b interface{}) (interface{}, error) { + date1, ok1 := a.(time.Time) + date2, ok2 := b.(time.Time) + + if ok1 && ok2 { + return date1.After(date2), nil + } + return nil, fmt.Errorf("unexpected operands types (%T) > (%T)", a, b) + }), + ) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // true +} + +func ExampleEvaluable() { + eval, err := gval.Full(gval.Constant("maximum_time", 52)). + NewEvaluable("response_time <= maximum_time") + if err != nil { + fmt.Println(err) + } + + for i := 50; i < 55; i++ { + value, err := eval(context.Background(), map[string]interface{}{ + "response_time": i, + }) + if err != nil { + fmt.Println(err) + + } + + fmt.Println(value) + } + + // Output: + // true + // true + // true + // false + // false +} + +func ExampleEvaluate_strlen() { + + value, err := gval.Evaluate(`strlen("someReallyLongInputString") <= 16`, + nil, + gval.Function("strlen", func(args ...interface{}) (interface{}, error) { + length := len(args[0].(string)) + return (float64)(length), nil + })) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // false +} + +func ExampleEvaluate_encoding() { + + value, err := gval.Evaluate(`(7 < "47" == true ? "hello world!\n\u263a" : "good bye\n")`+" + ` more text`", + nil, + gval.Function("strlen", func(args ...interface{}) (interface{}, error) { + length := len(args[0].(string)) + return (float64)(length), nil + })) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // hello world! + // ☺ more text +} + +type exampleType struct { + Hello string +} + +func (e exampleType) World() string { + return "world" +} + +func ExampleEvaluate_accessor() { + + value, err := gval.Evaluate(`foo.Hello + foo.World()`, + map[string]interface{}{ + "foo": exampleType{Hello: "hello "}, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // hello world +} + +func ExampleEvaluate_flatAccessor() { + + value, err := gval.Evaluate(`Hello + World()`, + exampleType{Hello: "hello "}, + ) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // hello world +} + +func ExampleEvaluate_nestedAccessor() { + + value, err := gval.Evaluate(`foo.Bar.Hello + foo.Bar.World()`, + map[string]interface{}{ + "foo": struct{ Bar exampleType }{ + Bar: exampleType{Hello: "hello "}, + }, + }) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // hello world +} + +func ExampleVariableSelector() { + value, err := gval.Evaluate(`hello.world`, + "!", + gval.VariableSelector(func(path gval.Evaluables) gval.Evaluable { + return func(c context.Context, v interface{}) (interface{}, error) { + keys, err := path.EvalStrings(c, v) + if err != nil { + return nil, err + } + return fmt.Sprintf("%s%s", strings.Join(keys, " "), v), nil + } + }), + ) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // hello world! +} + +func ExampleEvaluable_EvalInt() { + eval, err := gval.Full().NewEvaluable("1 + x") + if err != nil { + fmt.Println(err) + return + } + + value, err := eval.EvalInt(context.Background(), map[string]interface{}{"x": 5}) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // 6 +} + +func ExampleEvaluable_EvalBool() { + eval, err := gval.Full().NewEvaluable("1 == x") + if err != nil { + fmt.Println(err) + return + } + + value, err := eval.EvalBool(context.Background(), map[string]interface{}{"x": 1}) + if err != nil { + fmt.Println(err) + } + + if value { + fmt.Print("yeah") + } + + // Output: + // yeah +} + +func ExampleEvaluate_jsonpath() { + + value, err := gval.Evaluate(`$["response-time"]`, + map[string]interface{}{ + "response-time": 100, + }, + jsonpath.Language(), + ) + if err != nil { + fmt.Println(err) + } + + fmt.Print(value) + + // Output: + // 100 +} + +func ExampleLanguage() { + lang := gval.NewLanguage(gval.JSON(), gval.Arithmetic(), + //pipe operator + gval.PostfixOperator("|", func(c context.Context, p *gval.Parser, pre gval.Evaluable) (gval.Evaluable, error) { + post, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + return func(c context.Context, v interface{}) (interface{}, error) { + v, err := pre(c, v) + if err != nil { + return nil, err + } + return post(c, v) + }, nil + })) + + eval, err := lang.NewEvaluable(`{"foobar": 50} | foobar + 100`) + if err != nil { + fmt.Println(err) + } + + value, err := eval(context.Background(), nil) + + if err != nil { + fmt.Println(err) + } + + fmt.Println(value) + + // Output: + // 150 +} + +type exampleCustomSelector struct{ hidden string } + +var _ gval.Selector = &exampleCustomSelector{} + +func (s *exampleCustomSelector) SelectGVal(ctx context.Context, k string) (interface{}, error) { + if k == "hidden" { + return s.hidden, nil + } + + return nil, nil +} + +func ExampleSelector() { + lang := gval.Base() + value, err := lang.Evaluate( + "myStruct.hidden", + map[string]interface{}{"myStruct": &exampleCustomSelector{hidden: "hello world"}}, + ) + + if err != nil { + fmt.Println(err) + } + + fmt.Println(value) + + // Output: + // hello world +} + +func parseSub(ctx context.Context, p *gval.Parser) (gval.Evaluable, error) { + return p.ParseSublanguage(ctx, subLang) +} + +var ( + superLang = gval.NewLanguage( + gval.PrefixExtension('$', parseSub), + ) + subLang = gval.NewLanguage( + gval.Init(func(ctx context.Context, p *gval.Parser) (gval.Evaluable, error) { return p.Const("hello world"), nil }), + ) +) + +func ExampleParser_ParseSublanguage() { + value, err := superLang.Evaluate("$", nil) + + if err != nil { + fmt.Println(err) + } + + fmt.Println(value) + + // Output: + // hello world +} diff --git a/gval/functions.go b/gval/functions.go new file mode 100644 index 0000000..39f050d --- /dev/null +++ b/gval/functions.go @@ -0,0 +1,128 @@ +package gval + +import ( + "context" + "fmt" + "reflect" +) + +type function func(ctx context.Context, arguments ...interface{}) (interface{}, error) + +func toFunc(f interface{}) function { + if f, ok := f.(func(arguments ...interface{}) (interface{}, error)); ok { + return function(func(ctx context.Context, arguments ...interface{}) (interface{}, error) { + var v interface{} + errCh := make(chan error, 1) + go func() { + defer func() { + if recovered := recover(); recovered != nil { + errCh <- fmt.Errorf("%v", recovered) + } + }() + result, err := f(arguments...) + v = result + errCh <- err + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errCh: + close(errCh) + return v, err + } + }) + } + if f, ok := f.(func(ctx context.Context, arguments ...interface{}) (interface{}, error)); ok { + return function(f) + } + + fun := reflect.ValueOf(f) + t := fun.Type() + return func(ctx context.Context, args ...interface{}) (interface{}, error) { + var v interface{} + errCh := make(chan error, 1) + go func() { + defer func() { + if recovered := recover(); recovered != nil { + errCh <- fmt.Errorf("%v", recovered) + } + }() + in, err := createCallArguments(ctx, t, args) + if err != nil { + errCh <- err + return + } + out := fun.Call(in) + + r := make([]interface{}, len(out)) + for i, e := range out { + r[i] = e.Interface() + } + + err = nil + errorInterface := reflect.TypeOf((*error)(nil)).Elem() + if len(r) > 0 && t.Out(len(r)-1).Implements(errorInterface) { + if r[len(r)-1] != nil { + err = r[len(r)-1].(error) + } + r = r[0 : len(r)-1] + } + + switch len(r) { + case 0: + v = nil + case 1: + v = r[0] + default: + v = r + } + errCh <- err + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errCh: + close(errCh) + return v, err + } + } +} + +func createCallArguments(ctx context.Context, t reflect.Type, args []interface{}) ([]reflect.Value, error) { + variadic := t.IsVariadic() + numIn := t.NumIn() + + // if first argument is a context, use the given execution context + if numIn > 0 { + thisFun := reflect.ValueOf(createCallArguments) + thisT := thisFun.Type() + if t.In(0) == thisT.In(0) { + args = append([]interface{}{ctx}, args...) + } + } + + if (!variadic && len(args) != numIn) || (variadic && len(args) < numIn-1) { + return nil, fmt.Errorf("invalid number of parameters") + } + + in := make([]reflect.Value, len(args)) + var inType reflect.Type + for i, arg := range args { + if !variadic || i < numIn-1 { + inType = t.In(i) + } else if i == numIn-1 { + inType = t.In(numIn - 1).Elem() + } + argVal := reflect.ValueOf(arg) + if arg == nil { + argVal = reflect.Zero(reflect.TypeOf((*interface{})(nil)).Elem()) + } else if !argVal.Type().AssignableTo(inType) { + return nil, fmt.Errorf("expected type %s for parameter %d but got %T", + inType.String(), i, arg) + } + in[i] = argVal + } + return in, nil +} diff --git a/gval/functions_test.go b/gval/functions_test.go new file mode 100644 index 0000000..0a9f8bf --- /dev/null +++ b/gval/functions_test.go @@ -0,0 +1,162 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" +) + +func Test_toFunc(t *testing.T) { + myError := fmt.Errorf("my error") + tests := []struct { + name string + function interface{} + arguments []interface{} + want interface{} + wantErr error + wantAnyErr bool + }{ + { + name: "empty", + function: func() {}, + }, + { + name: "one arg", + function: func(a interface{}) { + if a != true { + panic("fail") + } + }, + arguments: []interface{}{true}, + }, + { + name: "three args", + function: func(a, b, c interface{}) { + if a != 1 || b != 2 || c != 3 { + panic("fail") + } + }, + arguments: []interface{}{1, 2, 3}, + }, + { + name: "input types", + function: func(a int, b string, c bool) { + if a != 1 || b != "2" || !c { + panic("fail") + } + }, + arguments: []interface{}{1, "2", true}, + }, + { + name: "wronge input type int", + function: func(a int, b string, c bool) {}, + arguments: []interface{}{"1", "2", true}, + wantAnyErr: true, + }, + { + name: "wronge input type string", + function: func(a int, b string, c bool) {}, + arguments: []interface{}{1, 2, true}, + wantAnyErr: true, + }, + { + name: "wronge input type bool", + function: func(a int, b string, c bool) {}, + arguments: []interface{}{1, "2", "true"}, + wantAnyErr: true, + }, + { + name: "wronge input number", + function: func(a int, b string, c bool) {}, + arguments: []interface{}{1, "2"}, + wantAnyErr: true, + }, + { + name: "one return", + function: func() bool { + return true + }, + want: true, + }, + { + name: "three returns", + function: func() (bool, string, int) { + return true, "2", 3 + }, + want: []interface{}{true, "2", 3}, + }, + { + name: "error", + function: func() error { + return myError + }, + wantErr: myError, + }, + { + name: "none error", + function: func() error { + return nil + }, + }, + { + name: "one return with error", + function: func() (bool, error) { + return false, myError + }, + want: false, + wantErr: myError, + }, + { + name: "three returns with error", + function: func() (bool, string, int, error) { + return false, "", 0, myError + }, + want: []interface{}{false, "", 0}, + wantErr: myError, + }, + { + name: "context not expiring", + function: func(ctx context.Context) error { + return nil + }, + }, + { + name: "context expires", + function: func(ctx context.Context) error { + time.Sleep(20 * time.Millisecond) + return nil + }, + wantErr: context.DeadlineExceeded, + }, + { + name: "nil arg", + function: func(a interface{}) bool { + return a == nil + }, + arguments: []interface{}{nil}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + got, err := toFunc(tt.function)(ctx, tt.arguments...) + cancel() + + if tt.wantAnyErr { + if err != nil { + return + } + t.Fatalf("toFunc()(args...) = error(nil), but wantAnyErr") + } + if err != tt.wantErr { + t.Fatalf("toFunc()(args...) = error(%v), wantErr (%v)", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("toFunc()(args...) = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/gval/gval.go b/gval/gval.go new file mode 100644 index 0000000..067d22a --- /dev/null +++ b/gval/gval.go @@ -0,0 +1,334 @@ +// Package gval provides a generic expression language. +// All functions, infix and prefix operators can be replaced by composing languages into a new one. +// +// The package contains concrete expression languages for common application in text, arithmetic, decimal arithmetic, propositional logic and so on. +// They can be used as basis for a custom expression language or to evaluate expressions directly. +package gval + +import ( + "context" + "fmt" + "reflect" + "text/scanner" + "time" + + "github.com/shopspring/decimal" +) + +// Evaluate given parameter with given expression in gval full language +func Evaluate(expression string, parameter interface{}, opts ...Language) (interface{}, error) { + return EvaluateWithContext(context.Background(), expression, parameter, opts...) +} + +// Evaluate given parameter with given expression in gval full language using a context +func EvaluateWithContext(c context.Context, expression string, parameter interface{}, opts ...Language) (interface{}, error) { + l := full + if len(opts) > 0 { + l = NewLanguage(append([]Language{l}, opts...)...) + } + return l.EvaluateWithContext(c, expression, parameter) +} + +// Full is the union of Arithmetic, Bitmask, Text, PropositionalLogic, TernaryOperator, and Json +// +// Operator in: a in b is true iff value a is an element of array b +// Operator ??: a ?? b returns a if a is not false or nil, otherwise n +// +// Function Date: Date(a) parses string a. a must match RFC3339, ISO8601, ruby date, or unix date +func Full(extensions ...Language) Language { + if len(extensions) == 0 { + return full + } + return NewLanguage(append([]Language{full}, extensions...)...) +} + +// TernaryOperator contains following Operator +// +// ?: a ? b : c returns b if bool a is true, otherwise b +func TernaryOperator() Language { + return ternaryOperator +} + +// Arithmetic contains base, plus(+), minus(-), divide(/), power(**), negative(-) +// and numerical order (<=,<,>,>=) +// +// Arithmetic operators expect float64 operands. +// Called with unfitting input, they try to convert the input to float64. +// They can parse strings and convert any type of int or float. +func Arithmetic() Language { + return arithmetic +} + +// DecimalArithmetic contains base, plus(+), minus(-), divide(/), power(**), negative(-) +// and numerical order (<=,<,>,>=) +// +// DecimalArithmetic operators expect decimal.Decimal operands (github.com/shopspring/decimal) +// and are used to calculate money/decimal rather than floating point calculations. +// Called with unfitting input, they try to convert the input to decimal.Decimal. +// They can parse strings and convert any type of int or float. +func DecimalArithmetic() Language { + return decimalArithmetic +} + +// Bitmask contains base, bitwise and(&), bitwise or(|) and bitwise not(^). +// +// Bitmask operators expect float64 operands. +// Called with unfitting input they try to convert the input to float64. +// They can parse strings and convert any type of int or float. +func Bitmask() Language { + return bitmask +} + +// Text contains base, lexical order on strings (<=,<,>,>=), +// regex match (=~) and regex not match (!~) +func Text() Language { + return text +} + +// PropositionalLogic contains base, not(!), and (&&), or (||) and Base. +// +// Propositional operator expect bool operands. +// Called with unfitting input they try to convert the input to bool. +// Numbers other than 0 and the strings "TRUE" and "true" are interpreted as true. +// 0 and the strings "FALSE" and "false" are interpreted as false. +func PropositionalLogic() Language { + return propositionalLogic +} + +// JSON contains json objects ({string:expression,...}) +// and json arrays ([expression, ...]) +func JSON() Language { + return ljson +} + +// Parentheses contains support for parentheses. +func Parentheses() Language { + return parentheses +} + +// Ident contains support for variables and functions. +func Ident() Language { + return ident +} + +// Base contains equal (==) and not equal (!=), perentheses and general support for variables, constants and functions +// It contains true, false, (floating point) number, string ("" or “) and char (”) constants +func Base() Language { + return base +} + +var full = NewLanguage(arithmetic, bitmask, text, propositionalLogic, ljson, + + InfixOperator("in", inArray), + + InfixShortCircuit("??", func(a interface{}) (interface{}, bool) { + v := reflect.ValueOf(a) + return a, a != nil && !v.IsZero() + }), + InfixOperator("??", func(a, b interface{}) (interface{}, error) { + if v := reflect.ValueOf(a); a == nil || v.IsZero() { + return b, nil + } + return a, nil + }), + + ternaryOperator, + + Function("date", func(arguments ...interface{}) (interface{}, error) { + if len(arguments) != 1 { + return nil, fmt.Errorf("date() expects exactly one string argument") + } + s, ok := arguments[0].(string) + if !ok { + return nil, fmt.Errorf("date() expects exactly one string argument") + } + for _, format := range [...]string{ + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.Kitchen, + time.RFC3339, + time.RFC3339Nano, + "2006-01-02", // RFC 3339 + "2006-01-02 15:04", // RFC 3339 with minutes + "2006-01-02 15:04:05", // RFC 3339 with seconds + "2006-01-02 15:04:05-07:00", // RFC 3339 with seconds and timezone + "2006-01-02T15Z0700", // ISO8601 with hour + "2006-01-02T15:04Z0700", // ISO8601 with minutes + "2006-01-02T15:04:05Z0700", // ISO8601 with seconds + "2006-01-02T15:04:05.999999999Z0700", // ISO8601 with nanoseconds + } { + ret, err := time.ParseInLocation(format, s, time.Local) + if err == nil { + return ret, nil + } + } + return nil, fmt.Errorf("date() could not parse %s", s) + }), +) + +var ternaryOperator = PostfixOperator("?", parseIf) + +var ljson = NewLanguage( + PrefixExtension('[', parseJSONArray), + PrefixExtension('{', parseJSONObject), +) + +var arithmetic = NewLanguage( + InfixNumberOperator("+", func(a, b uint) (interface{}, error) { return a + b, nil }), + InfixNumberOperator("-", func(a, b uint) (interface{}, error) { return a - b, nil }), + InfixNumberOperator("*", func(a, b uint) (interface{}, error) { return a * b, nil }), + InfixNumberOperator("/", func(a, b uint) (interface{}, error) { return a / b, nil }), + InfixNumberOperator("%", func(a, b uint) (interface{}, error) { return a % b, nil }), + + InfixNumberOperator(">", func(a, b uint) (interface{}, error) { return a > b, nil }), + InfixNumberOperator(">=", func(a, b uint) (interface{}, error) { return a >= b, nil }), + InfixNumberOperator("<", func(a, b uint) (interface{}, error) { return a < b, nil }), + InfixNumberOperator("<=", func(a, b uint) (interface{}, error) { return a <= b, nil }), + + InfixNumberOperator("==", func(a, b uint) (interface{}, error) { return a == b, nil }), + InfixNumberOperator("!=", func(a, b uint) (interface{}, error) { return a != b, nil }), + + base, +) + +var decimalArithmetic = NewLanguage( + InfixDecimalOperator("+", func(a, b decimal.Decimal) (interface{}, error) { return a.Add(b), nil }), + InfixDecimalOperator("-", func(a, b decimal.Decimal) (interface{}, error) { return a.Sub(b), nil }), + InfixDecimalOperator("*", func(a, b decimal.Decimal) (interface{}, error) { return a.Mul(b), nil }), + InfixDecimalOperator("/", func(a, b decimal.Decimal) (interface{}, error) { return a.Div(b), nil }), + InfixDecimalOperator("%", func(a, b decimal.Decimal) (interface{}, error) { return a.Mod(b), nil }), + InfixDecimalOperator("**", func(a, b decimal.Decimal) (interface{}, error) { return a.Pow(b), nil }), + + InfixDecimalOperator(">", func(a, b decimal.Decimal) (interface{}, error) { return a.GreaterThan(b), nil }), + InfixDecimalOperator(">=", func(a, b decimal.Decimal) (interface{}, error) { return a.GreaterThanOrEqual(b), nil }), + InfixDecimalOperator("<", func(a, b decimal.Decimal) (interface{}, error) { return a.LessThan(b), nil }), + InfixDecimalOperator("<=", func(a, b decimal.Decimal) (interface{}, error) { return a.LessThanOrEqual(b), nil }), + + InfixDecimalOperator("==", func(a, b decimal.Decimal) (interface{}, error) { return a.Equal(b), nil }), + InfixDecimalOperator("!=", func(a, b decimal.Decimal) (interface{}, error) { return !a.Equal(b), nil }), + base, + //Base is before these overrides so that the Base options are overridden + PrefixExtension(scanner.Int, parseDecimal), + PrefixExtension(scanner.Float, parseDecimal), + PrefixOperator("-", func(c context.Context, v interface{}) (interface{}, error) { + i, ok := convertToUint(v) + if !ok { + return nil, fmt.Errorf("unexpected %v(%T) expected number", v, v) + } + return -i, nil + }), +) + +var bitmask = NewLanguage( + InfixNumberOperator("^", func(a, b uint) (interface{}, error) { return uint(int64(a) ^ int64(b)), nil }), + InfixNumberOperator("&", func(a, b uint) (interface{}, error) { return uint(int64(a) & int64(b)), nil }), + InfixNumberOperator("|", func(a, b uint) (interface{}, error) { return uint(int64(a) | int64(b)), nil }), + InfixNumberOperator("<<", func(a, b uint) (interface{}, error) { return uint(int64(a) << uint64(b)), nil }), + InfixNumberOperator(">>", func(a, b uint) (interface{}, error) { return uint(int64(a) >> uint64(b)), nil }), + + PrefixOperator("~", func(c context.Context, v interface{}) (interface{}, error) { + i, ok := convertToUint(v) + if !ok { + return nil, fmt.Errorf("unexpected %T expected number", v) + } + return float64(^int64(i)), nil + }), +) + +var text = NewLanguage( + InfixTextOperator("+", func(a, b string) (interface{}, error) { return fmt.Sprintf("%v%v", a, b), nil }), + + InfixTextOperator("<", func(a, b string) (interface{}, error) { return a < b, nil }), + InfixTextOperator("<=", func(a, b string) (interface{}, error) { return a <= b, nil }), + InfixTextOperator(">", func(a, b string) (interface{}, error) { return a > b, nil }), + InfixTextOperator(">=", func(a, b string) (interface{}, error) { return a >= b, nil }), + + InfixEvalOperator("=~", regEx), + InfixEvalOperator("!~", notRegEx), + base, +) + +var propositionalLogic = NewLanguage( + PrefixOperator("!", func(c context.Context, v interface{}) (interface{}, error) { + b, ok := convertToBool(v) + if !ok { + return nil, fmt.Errorf("unexpected %T expected bool", v) + } + return !b, nil + }), + + InfixShortCircuit("&&", func(a interface{}) (interface{}, bool) { return false, a == false }), + InfixBoolOperator("&&", func(a, b bool) (interface{}, error) { return a && b, nil }), + InfixShortCircuit("||", func(a interface{}) (interface{}, bool) { return true, a == true }), + InfixBoolOperator("||", func(a, b bool) (interface{}, error) { return a || b, nil }), + + InfixBoolOperator("==", func(a, b bool) (interface{}, error) { return a == b, nil }), + InfixBoolOperator("!=", func(a, b bool) (interface{}, error) { return a != b, nil }), + + base, +) + +var parentheses = NewLanguage( + PrefixExtension('(', parseParentheses), +) + +var ident = NewLanguage( + PrefixMetaPrefix(scanner.Ident, parseIdent), +) + +var base = NewLanguage( + PrefixExtension(scanner.Int, parseNumber), + PrefixExtension(scanner.Float, parseNumber), + PrefixOperator("-", func(c context.Context, v interface{}) (interface{}, error) { + i, ok := convertToUint(v) + if !ok { + return nil, fmt.Errorf("unexpected %v(%T) expected number", v, v) + } + return -i, nil + }), + + PrefixExtension(scanner.String, parseString), + PrefixExtension(scanner.Char, parseString), + PrefixExtension(scanner.RawString, parseString), + + Constant("true", true), + Constant("false", false), + + InfixOperator("==", func(a, b interface{}) (interface{}, error) { return reflect.DeepEqual(a, b), nil }), + InfixOperator("!=", func(a, b interface{}) (interface{}, error) { return !reflect.DeepEqual(a, b), nil }), + parentheses, + + Precedence("??", 0), + + Precedence("||", 20), + Precedence("&&", 21), + + Precedence("==", 40), + Precedence("!=", 40), + Precedence(">", 40), + Precedence(">=", 40), + Precedence("<", 40), + Precedence("<=", 40), + Precedence("=~", 40), + Precedence("!~", 40), + Precedence("in", 40), + + Precedence("^", 60), + Precedence("&", 60), + Precedence("|", 60), + + Precedence("<<", 90), + Precedence(">>", 90), + + Precedence("+", 120), + Precedence("-", 120), + + Precedence("*", 150), + Precedence("/", 150), + Precedence("%", 150), + + Precedence("**", 200), + + ident, +) diff --git a/gval/gval_evaluationFailure_test.go b/gval/gval_evaluationFailure_test.go new file mode 100644 index 0000000..b52500b --- /dev/null +++ b/gval/gval_evaluationFailure_test.go @@ -0,0 +1,396 @@ +package gval + +/* + Tests to make sure evaluation fails in the expected ways. +*/ +import ( + "errors" + "fmt" + "testing" +) + +func TestModifierTyping(test *testing.T) { + var ( + invalidOperator = "invalid operation" + unknownParameter = "unknown parameter" + invalidRegex = "error parsing regex" + tooFewArguments = "reflect: Call with too few input arguments" + tooManyArguments = "reflect: Call with too many input arguments" + mismatchedParameters = "reflect: Call using" + custom = "test error" + ) + evaluationTests := []evaluationTest{ + //ModifierTyping + { + name: "PLUS literal number to literal bool", + expression: "1 + true", + want: "1true", // + on string is defined + }, + { + name: "PLUS number to bool", + expression: "number + bool", + want: "1true", // + on string is defined + }, + { + name: "MINUS number to bool", + expression: "number - bool", + wantErr: invalidOperator, + }, + { + name: "MINUS number to bool", + expression: "number - bool", + wantErr: invalidOperator, + }, + { + name: "MULTIPLY number to bool", + expression: "number * bool", + wantErr: invalidOperator, + }, + { + name: "DIVIDE number to bool", + expression: "number / bool", + wantErr: invalidOperator, + }, + { + name: "EXPONENT number to bool", + expression: "number ** bool", + wantErr: invalidOperator, + }, + { + name: "MODULUS number to bool", + expression: "number % bool", + wantErr: invalidOperator, + }, + { + name: "XOR number to bool", + expression: "number % bool", + wantErr: invalidOperator, + }, + { + name: "BITWISE_OR number to bool", + expression: "number | bool", + wantErr: invalidOperator, + }, + { + name: "BITWISE_AND number to bool", + expression: "number & bool", + wantErr: invalidOperator, + }, + { + name: "BITWISE_XOR number to bool", + expression: "number ^ bool", + wantErr: invalidOperator, + }, + { + name: "BITWISE_LSHIFT number to bool", + expression: "number << bool", + wantErr: invalidOperator, + }, + { + name: "BITWISE_RSHIFT number to bool", + expression: "number >> bool", + wantErr: invalidOperator, + }, + //LogicalOperatorTyping + { + name: "AND number to number", + expression: "number && number", + want: true, // number != 0 is true + }, + { + + name: "OR number to number", + expression: "number || number", + want: true, // number != 0 is true + }, + { + name: "AND string to string", + expression: "string && string", + wantErr: invalidOperator, + }, + { + name: "OR string to string", + expression: "string || string", + wantErr: invalidOperator, + }, + { + name: "AND number to string", + expression: "number && string", + wantErr: invalidOperator, + }, + { + name: "OR number to string", + expression: "number || string", + wantErr: invalidOperator, + }, + { + name: "AND bool to string", + expression: "bool && string", + wantErr: invalidOperator, + }, + { + name: "OR string to bool", + expression: "string || bool", + wantErr: invalidOperator, + }, + //ComparatorTyping + { + name: "GT literal bool to literal bool", + expression: "true > true", + want: false, //lexical order on "true" + }, + { + name: "GT bool to bool", + expression: "bool > bool", + want: false, //lexical order on "true" + }, + { + name: "GTE bool to bool", + expression: "bool >= bool", + want: true, //lexical order on "true" + }, + { + name: "LT bool to bool", + expression: "bool < bool", + want: false, //lexical order on "true" + }, + { + name: "LTE bool to bool", + expression: "bool <= bool", + want: true, //lexical order on "true" + }, + { + name: "GT number to string", + expression: "number > string", + want: false, //lexical order "1" < "foo" + }, + { + + name: "GTE number to string", + expression: "number >= string", + want: false, //lexical order "1" < "foo" + }, + { + name: "LT number to string", + expression: "number < string", + want: true, //lexical order "1" < "foo" + }, + { + name: "REQ number to string", + expression: "number =~ string", + want: false, + }, + { + name: "REQ number to bool", + expression: "number =~ bool", + want: false, + }, + { + name: "REQ bool to number", + expression: "bool =~ number", + want: false, + }, + { + name: "REQ bool to string", + expression: "bool =~ string", + want: false, + }, + { + name: "NREQ number to string", + expression: "number !~ string", + want: true, + }, + { + name: "NREQ number to bool", + expression: "number !~ bool", + want: true, + }, + { + name: "NREQ bool to number", + expression: "bool !~ number", + want: true, + }, + { + + name: "NREQ bool to string", + expression: "bool !~ string", + want: true, + }, + { + name: "IN non-array numeric", + expression: "1 in 2", + wantErr: "expected type []interface{} for in operator but got float64", + }, + { + name: "IN non-array string", + expression: `1 in "foo"`, + wantErr: "expected type []interface{} for in operator but got string", + }, + { + + name: "IN non-array boolean", + expression: "1 in true", + wantErr: "expected type []interface{} for in operator but got bool", + }, + //TernaryTyping + { + name: "Ternary with number", + expression: "10 ? true", + want: true, // 10 != nil && 10 != false + }, + { + name: "Ternary with string", + expression: `"foo" ? true`, + want: true, // "foo" != nil && "foo" != false + }, + //RegexParameterCompilation + { + name: "Regex equality runtime parsing", + expression: `"foo" =~ foo`, + parameter: map[string]interface{}{ + "foo": "[foo", + }, + wantErr: invalidRegex, + }, + { + name: "Regex inequality runtime parsing", + expression: `"foo" !~ foo`, + parameter: map[string]interface{}{ + "foo": "[foo", + }, + wantErr: invalidRegex, + }, + { + name: "Regex equality runtime right side evaluation", + expression: `"foo" =~ error()`, + wantErr: custom, + }, + { + name: "Regex inequality runtime right side evaluation", + expression: `"foo" !~ error()`, + wantErr: custom, + }, + { + name: "Regex equality runtime left side evaluation", + expression: `error() =~ "."`, + wantErr: custom, + }, + { + name: "Regex inequality runtime left side evaluation", + expression: `error() !~ "."`, + wantErr: custom, + }, + //FuncExecution + { + name: "Func error bubbling", + expression: "error()", + extension: Function("error", func(arguments ...interface{}) (interface{}, error) { + return nil, errors.New("Huge problems") + }), + wantErr: "Huge problems", + }, + //InvalidParameterCalls + { + name: "Missing parameter field reference", + expression: "foo.NotExists", + parameter: fooFailureParameters, + wantErr: unknownParameter, + }, + { + name: "Parameter method call on missing function", + expression: "foo.NotExist()", + parameter: fooFailureParameters, + wantErr: unknownParameter, + }, + { + name: "Nested missing parameter field reference", + expression: "foo.Nested.NotExists", + parameter: fooFailureParameters, + wantErr: unknownParameter, + }, + { + name: "Parameter method call returns error", + expression: "foo.AlwaysFail()", + parameter: fooFailureParameters, + wantErr: "function should always fail", + }, + { + name: "Too few arguments to parameter call", + expression: "foo.FuncArgStr()", + parameter: fooFailureParameters, + wantErr: tooFewArguments, + }, + { + name: "Too many arguments to parameter call", + expression: `foo.FuncArgStr("foo", "bar", 15)`, + parameter: fooFailureParameters, + wantErr: tooManyArguments, + }, + { + name: "Mismatched parameters", + expression: "foo.FuncArgStr(5)", + parameter: fooFailureParameters, + wantErr: mismatchedParameters, + }, + { + name: "Negative Array Index", + expression: "foo[-1]", + parameter: map[string]interface{}{ + "foo": []int{1, 2, 3}, + }, + wantErr: unknownParameter, + }, + { + name: "Nested slice call index out of bound", + expression: `foo.Nested.Slice[10]`, + parameter: map[string]interface{}{"foo": foo}, + wantErr: unknownParameter, + }, + { + name: "Nested map call missing key", + expression: `foo.Nested.Map["d"]`, + parameter: map[string]interface{}{"foo": foo}, + wantErr: unknownParameter, + }, + { + name: "invalid selector", + expression: "hello[world()]", + extension: NewLanguage(Base(), Function("world", func() (int, error) { + return 0, fmt.Errorf("test error") + })), + wantErr: "test error", + }, + { + name: "eval `nil > 1` returns true #23", + expression: `nil > 1`, + wantErr: "invalid operation () > (float64)", + }, + { + name: "map with unknown func", + expression: `foo.MapWithFunc.NotExist()`, + parameter: map[string]interface{}{"foo": foo}, + wantErr: unknownParameter, + }, + { + name: "map with unknown func", + expression: `foo.SliceWithFunc.NotExist()`, + parameter: map[string]interface{}{"foo": foo}, + wantErr: unknownParameter, + }, + } + + for i := range evaluationTests { + if evaluationTests[i].parameter == nil { + evaluationTests[i].parameter = map[string]interface{}{ + "number": 1, + "string": "foo", + "bool": true, + "error": func() (int, error) { + return 0, fmt.Errorf("test error") + }, + } + } + } + + testEvaluate(evaluationTests, test) +} diff --git a/gval/gval_noparameter_test.go b/gval/gval_noparameter_test.go new file mode 100644 index 0000000..e0af69c --- /dev/null +++ b/gval/gval_noparameter_test.go @@ -0,0 +1,818 @@ +package gval + +import ( + "context" + "fmt" + "testing" + "text/scanner" +) + +func TestNoParameter(t *testing.T) { + testEvaluate( + []evaluationTest{ + { + name: "Number", + expression: "100", + want: 100.0, + }, + { + name: "Single PLUS", + expression: "51 + 49", + want: 100.0, + }, + { + name: "Single MINUS", + expression: "100 - 51", + want: 49.0, + }, + { + name: "Single BITWISE AND", + expression: "100 & 50", + want: 32.0, + }, + { + name: "Single BITWISE OR", + expression: "100 | 50", + want: 118.0, + }, + { + name: "Single BITWISE XOR", + expression: "100 ^ 50", + want: 86.0, + }, + { + name: "Single shift left", + expression: "2 << 1", + want: 4.0, + }, + { + name: "Single shift right", + expression: "2 >> 1", + want: 1.0, + }, + { + name: "Single BITWISE NOT", + expression: "~10", + want: -11.0, + }, + { + + name: "Single MULTIPLY", + expression: "5 * 20", + want: 100.0, + }, + { + + name: "Single DIVIDE", + expression: "100 / 20", + want: 5.0, + }, + { + + name: "Single even MODULUS", + expression: "100 % 2", + want: 0.0, + }, + { + name: "Single odd MODULUS", + expression: "101 % 2", + want: 1.0, + }, + { + + name: "Single EXPONENT", + expression: "10 ** 2", + want: 100.0, + }, + { + + name: "Compound PLUS", + expression: "20 + 30 + 50", + want: 100.0, + }, + { + + name: "Compound BITWISE AND", + expression: "20 & 30 & 50", + want: 16.0, + }, + { + name: "Mutiple operators", + expression: "20 * 5 - 49", + want: 51.0, + }, + { + name: "Parenthesis usage", + expression: "100 - (5 * 10)", + want: 50.0, + }, + { + + name: "Nested parentheses", + expression: "50 + (5 * (15 - 5))", + want: 100.0, + }, + { + + name: "Nested parentheses with bitwise", + expression: "100 ^ (23 * (2 | 5))", + want: 197.0, + }, + { + name: "Logical OR operation of two clauses", + expression: "(1 == 1) || (true == true)", + want: true, + }, + { + name: "Logical AND operation of two clauses", + expression: "(1 == 1) && (true == true)", + want: true, + }, + { + + name: "Implicit boolean", + expression: "2 > 1", + want: true, + }, + { + name: "Equal test minus numbers and no spaces", + expression: "-1==-1", + want: true, + }, + { + + name: "Compound boolean", + expression: "5 < 10 && 1 < 5", + want: true, + }, + { + name: "Evaluated true && false operation (for issue #8)", + expression: "1 > 10 && 11 > 10", + want: false, + }, + { + + name: "Evaluated true && false operation (for issue #8)", + expression: "true == true && false == true", + want: false, + }, + { + + name: "Parenthesis boolean", + expression: "10 < 50 && (1 != 2 && 1 > 0)", + want: true, + }, + { + name: "Comparison of string constants", + expression: `"foo" == "foo"`, + want: true, + }, + { + + name: "NEQ comparison of string constants", + expression: `"foo" != "bar"`, + want: true, + }, + { + + name: "REQ comparison of string constants", + expression: `"foobar" =~ "oba"`, + want: true, + }, + { + + name: "NREQ comparison of string constants", + expression: `"foo" !~ "bar"`, + want: true, + }, + { + + name: "Multiplicative/additive order", + expression: "5 + 10 * 2", + want: 25.0, + }, + { + name: "Multiple constant multiplications", + expression: "10 * 10 * 10", + want: 1000.0, + }, + { + + name: "Multiple adds/multiplications", + expression: "10 * 10 * 10 + 1 * 10 * 10", + want: 1100.0, + }, + { + + name: "Modulus operatorPrecedence", + expression: "1 + 101 % 2 * 5", + want: 6.0, + }, + { + name: "Exponent operatorPrecedence", + expression: "1 + 5 ** 3 % 2 * 5", + want: 6.0, + }, + { + + name: "Bit shift operatorPrecedence", + expression: "50 << 1 & 90", + want: 64.0, + }, + { + + name: "Bit shift operatorPrecedence", + expression: "90 & 50 << 1", + want: 64.0, + }, + { + + name: "Bit shift operatorPrecedence amongst non-bitwise", + expression: "90 + 50 << 1 * 5", + want: 4480.0, + }, + { + name: "Order of non-commutative same-operatorPrecedence operators (additive)", + expression: "1 - 2 - 4 - 8", + want: -13.0, + }, + { + name: "Order of non-commutative same-operatorPrecedence operators (multiplicative)", + expression: "1 * 4 / 2 * 8", + want: 16.0, + }, + { + name: "Null coalesce operatorPrecedence", + expression: "true ?? true ? 100 + 200 : 400", + want: 300.0, + }, + { + name: "Identical date equivalence", + expression: `"2014-01-02 14:12:22" == "2014-01-02 14:12:22"`, + want: true, + }, + { + name: "Positive date GT", + expression: `"2014-01-02 14:12:22" > "2014-01-02 12:12:22"`, + want: true, + }, + { + name: "Negative date GT", + expression: `"2014-01-02 14:12:22" > "2014-01-02 16:12:22"`, + want: false, + }, + { + name: "Positive date GTE", + expression: `"2014-01-02 14:12:22" >= "2014-01-02 12:12:22"`, + want: true, + }, + { + name: "Negative date GTE", + expression: `"2014-01-02 14:12:22" >= "2014-01-02 16:12:22"`, + want: false, + }, + { + name: "Positive date LT", + expression: `"2014-01-02 14:12:22" < "2014-01-02 16:12:22"`, + want: true, + }, + { + + name: "Negative date LT", + expression: `"2014-01-02 14:12:22" < "2014-01-02 11:12:22"`, + want: false, + }, + { + + name: "Positive date LTE", + expression: `"2014-01-02 09:12:22" <= "2014-01-02 12:12:22"`, + want: true, + }, + { + name: "Negative date LTE", + expression: `"2014-01-02 14:12:22" <= "2014-01-02 11:12:22"`, + want: false, + }, + { + + name: "Sign prefix comparison", + expression: "-1 < 0", + want: true, + }, + { + + name: "Lexicographic LT", + expression: `"ab" < "abc"`, + want: true, + }, + { + name: "Lexicographic LTE", + expression: `"ab" <= "abc"`, + want: true, + }, + { + + name: "Lexicographic GT", + expression: `"aba" > "abc"`, + want: false, + }, + { + + name: "Lexicographic GTE", + expression: `"aba" >= "abc"`, + want: false, + }, + { + + name: "Boolean sign prefix comparison", + expression: "!true == false", + want: true, + }, + { + name: "Inversion of clause", + expression: "!(10 < 0)", + want: true, + }, + { + + name: "Negation after modifier", + expression: "10 * -10", + want: -100.0, + }, + { + + name: "Ternary with single boolean", + expression: "true ? 10", + want: 10.0, + }, + { + + name: "Ternary nil with single boolean", + expression: "false ? 10", + want: nil, + }, + { + name: "Ternary with comparator boolean", + expression: "10 > 5 ? 35.50", + want: 35.50, + }, + { + + name: "Ternary nil with comparator boolean", + expression: "1 > 5 ? 35.50", + want: nil, + }, + { + + name: "Ternary with parentheses", + expression: "(5 * (15 - 5)) > 5 ? 35.50", + want: 35.50, + }, + { + + name: "Ternary operatorPrecedence", + expression: "true ? 35.50 > 10", + want: true, + }, + { + name: "Ternary-else", + expression: "false ? 35.50 : 50", + want: 50.0, + }, + { + + name: "Ternary-else inside clause", + expression: "(false ? 5 : 35.50) > 10", + want: true, + }, + { + + name: "Ternary-else (true-case) inside clause", + expression: "(true ? 1 : 5) < 10", + want: true, + }, + { + + name: "Ternary-else before comparator (negative case)", + expression: "true ? 1 : 5 > 10", + want: 1.0, + }, + { + name: "Nested ternaries (#32)", + expression: "(2 == 2) ? 1 : (true ? 2 : 3)", + want: 1.0, + }, + { + + name: "Nested ternaries, right case (#32)", + expression: "false ? 1 : (true ? 2 : 3)", + want: 2.0, + }, + { + + name: "Doubly-nested ternaries (#32)", + expression: "true ? (false ? 1 : (false ? 2 : 3)) : (false ? 4 : 5)", + want: 3.0, + }, + { + + name: "String to string concat", + expression: `"foo" + "bar" == "foobar"`, + want: true, + }, + { + name: "String to float64 concat", + expression: `"foo" + 123 == "foo123"`, + want: true, + }, + { + + name: "Float64 to string concat", + expression: `123 + "bar" == "123bar"`, + want: true, + }, + { + + name: "String to date concat", + expression: `"foo" + "02/05/1970" == "foobar"`, + want: false, + }, + { + + name: "String to bool concat", + expression: `"foo" + true == "footrue"`, + want: true, + }, + { + name: "Bool to string concat", + expression: `true + "bar" == "truebar"`, + want: true, + }, + { + + name: "Null coalesce left", + expression: "1 ?? 2", + want: 1.0, + }, + { + + name: "Array membership literals", + expression: "1 in [1, 2, 3]", + want: true, + }, + { + + name: "Array membership literal with inversion", + expression: "!(1 in [1, 2, 3])", + want: false, + }, + { + name: "Logical operator reordering (#30)", + expression: "(true && true) || (true && false)", + want: true, + }, + { + + name: "Logical operator reordering without parens (#30)", + expression: "true && true || true && false", + want: true, + }, + { + + name: "Logical operator reordering with multiple OR (#30)", + expression: "false || true && true || false", + want: true, + }, + { + name: "Left-side multiple consecutive (should be reordered) operators", + expression: "(10 * 10 * 10) > 10", + want: true, + }, + { + + name: "Three-part non-paren logical op reordering (#44)", + expression: "false && true || true", + want: true, + }, + { + + name: "Three-part non-paren logical op reordering (#44), second one", + expression: "true || false && true", + want: true, + }, + { + name: "Logical operator reordering without parens (#45)", + expression: "true && true || false && false", + want: true, + }, + { + name: "Single function", + expression: "foo()", + extension: Function("foo", func(arguments ...interface{}) (interface{}, error) { + return true, nil + }), + + want: true, + }, + { + name: "Func with argument", + expression: "passthrough(1)", + extension: Function("passthrough", func(arguments ...interface{}) (interface{}, error) { + return arguments[0], nil + }), + want: 1.0, + }, + { + name: "Func with arguments", + expression: "passthrough(1, 2)", + extension: Function("passthrough", func(arguments ...interface{}) (interface{}, error) { + return arguments[0].(float64) + arguments[1].(float64), nil + }), + want: 3.0, + }, + { + name: "Nested function with operatorPrecedence", + expression: "sum(1, sum(2, 3), 2 + 2, true ? 4 : 5)", + extension: Function("sum", func(arguments ...interface{}) (interface{}, error) { + sum := 0.0 + for _, v := range arguments { + sum += v.(float64) + } + return sum, nil + }), + want: 14.0, + }, + { + name: "Empty function and modifier, compared", + expression: "numeric()-1 > 0", + extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) { + return 2.0, nil + }), + want: true, + }, + { + name: "Empty function comparator", + expression: "numeric() > 0", + extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) { + return 2.0, nil + }), + want: true, + }, + { + + name: "Empty function logical operator", + expression: "success() && !false", + extension: Function("success", func(arguments ...interface{}) (interface{}, error) { + return true, nil + }), + want: true, + }, + { + name: "Empty function ternary", + expression: "nope() ? 1 : 2.0", + extension: Function("nope", func(arguments ...interface{}) (interface{}, error) { + return false, nil + }), + want: 2.0, + }, + { + + name: "Empty function null coalesce", + expression: "null() ?? 2", + extension: Function("null", func(arguments ...interface{}) (interface{}, error) { + return nil, nil + }), + want: 2.0, + }, + { + name: "Empty function with prefix", + expression: "-ten()", + extension: Function("ten", func(arguments ...interface{}) (interface{}, error) { + return 10.0, nil + }), + want: -10.0, + }, + { + name: "Empty function as part of chain", + expression: "10 - numeric() - 2", + extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) { + return 5.0, nil + }), + want: 3.0, + }, + { + name: "Empty function near separator", + expression: "10 in [1, 2, 3, ten(), 8]", + extension: Function("ten", func(arguments ...interface{}) (interface{}, error) { + return 10.0, nil + }), + want: true, + }, + { + name: "Enclosed empty function with modifier and comparator (#28)", + expression: "(ten() - 1) > 3", + extension: Function("ten", func(arguments ...interface{}) (interface{}, error) { + return 10.0, nil + }), + want: true, + }, + { + name: "Array", + expression: `[(ten() - 1) > 3, (ten() - 1),"hey"]`, + extension: Function("ten", func(arguments ...interface{}) (interface{}, error) { + return 10.0, nil + }), + want: []interface{}{true, 9., "hey"}, + }, + { + name: "Object", + expression: `{1: (ten() - 1) > 3, 7 + ".X" : (ten() - 1),"hello" : "hey"}`, + extension: Function("ten", func(arguments ...interface{}) (interface{}, error) { + return 10.0, nil + }), + want: map[string]interface{}{"1": true, "7.X": 9., "hello": "hey"}, + }, + { + name: "Object negativ value", + expression: `{1: -1,"hello" : "hey"}`, + want: map[string]interface{}{"1": -1., "hello": "hey"}, + }, + { + name: "Empty Array", + expression: `[]`, + want: []interface{}{}, + }, + { + name: "Empty Object", + expression: `{}`, + want: map[string]interface{}{}, + }, + { + name: "Variadic", + expression: `sum(1,2,3,4)`, + extension: Function("sum", func(arguments ...float64) (interface{}, error) { + sum := 0. + for _, a := range arguments { + sum += a + } + return sum, nil + }), + want: 10.0, + }, + { + name: "Ident Operator", + expression: `1 plus 1`, + extension: InfixNumberOperator("plus", func(a, b float64) (interface{}, error) { + return a + b, nil + }), + want: 2.0, + }, + { + name: "Postfix Operator", + expression: `4§`, + extension: PostfixOperator("§", func(_ context.Context, _ *Parser, eval Evaluable) (Evaluable, error) { + return func(ctx context.Context, parameter interface{}) (interface{}, error) { + i, err := eval.EvalInt(ctx, parameter) + if err != nil { + return nil, err + } + return fmt.Sprintf("§%d", i), nil + }, nil + }), + want: "§4", + }, + { + name: "Tabs as non-whitespace", + expression: "4\t5\t6", + extension: NewLanguage( + Init(func(ctx context.Context, p *Parser) (Evaluable, error) { + p.SetWhitespace('\n', '\r', ' ') + return p.ParseExpression(ctx) + }), + InfixNumberOperator("\t", func(a, b float64) (interface{}, error) { + return a * b, nil + }), + ), + want: 120.0, + }, + { + name: "Handle all other prefixes", + expression: "^foo + $bar + &baz", + extension: DefaultExtension(func(ctx context.Context, p *Parser) (Evaluable, error) { + var mul int + switch p.TokenText() { + case "^": + mul = 1 + case "$": + mul = 2 + case "&": + mul = 3 + } + + switch p.Scan() { + case scanner.Ident: + return p.Const(mul * len(p.TokenText())), nil + default: + return nil, p.Expected("length multiplier", scanner.Ident) + } + }), + want: 18.0, + }, + { + name: "Embed languages", + expression: "left { 5 + 5 } right", + extension: func() Language { + step := func(ctx context.Context, p *Parser, cur Evaluable) (Evaluable, error) { + next, err := p.ParseExpression(ctx) + if err != nil { + return nil, err + } + + return func(ctx context.Context, parameter interface{}) (interface{}, error) { + us, err := cur.EvalString(ctx, parameter) + if err != nil { + return nil, err + } + + them, err := next.EvalString(ctx, parameter) + if err != nil { + return nil, err + } + + return us + them, nil + }, nil + } + + return NewLanguage( + Init(func(ctx context.Context, p *Parser) (Evaluable, error) { + p.SetWhitespace() + p.SetMode(0) + + return p.ParseExpression(ctx) + }), + DefaultExtension(func(ctx context.Context, p *Parser) (Evaluable, error) { + return step(ctx, p, p.Const(p.TokenText())) + }), + PrefixExtension(scanner.EOF, func(ctx context.Context, p *Parser) (Evaluable, error) { + return p.Const(""), nil + }), + PrefixExtension('{', func(ctx context.Context, p *Parser) (Evaluable, error) { + eval, err := p.ParseSublanguage(ctx, Full()) + if err != nil { + return nil, err + } + + switch p.Scan() { + case '}': + default: + return nil, p.Expected("embedded", '}') + } + + return step(ctx, p, eval) + }), + ) + }(), + want: "left 10 right", + }, + { + name: "Late binding", + expression: "5 * [ 10 * { 20 / [ 10 ] } ]", + extension: func() Language { + var inner, outer Language + + parseCurly := func(ctx context.Context, p *Parser) (Evaluable, error) { + eval, err := p.ParseSublanguage(ctx, outer) + if err != nil { + return nil, err + } + + if p.Scan() != '}' { + return nil, p.Expected("end", '}') + } + + return eval, nil + } + + parseSquare := func(ctx context.Context, p *Parser) (Evaluable, error) { + eval, err := p.ParseSublanguage(ctx, inner) + if err != nil { + return nil, err + } + + if p.Scan() != ']' { + return nil, p.Expected("end", ']') + } + + return eval, nil + } + + inner = Full(PrefixExtension('{', parseCurly)) + outer = Full(PrefixExtension('[', parseSquare)) + return outer + }(), + want: 100.0, + }, + }, + t, + ) +} diff --git a/gval/gval_parameterized_test.go b/gval/gval_parameterized_test.go new file mode 100644 index 0000000..c1f1bd3 --- /dev/null +++ b/gval/gval_parameterized_test.go @@ -0,0 +1,719 @@ +package gval + +import ( + "context" + "fmt" + "regexp" + "strings" + "testing" + "time" + + "github.com/shopspring/decimal" +) + +func TestParameterized(t *testing.T) { + testEvaluate( + []evaluationTest{ + { + name: "Single parameter modified by constant", + expression: "foo + 2", + parameter: map[string]interface{}{ + "foo": 2.0, + }, + want: 4.0, + }, + { + + name: "Single parameter modified by variable", + expression: "foo * bar", + parameter: map[string]interface{}{ + "foo": 5.0, + "bar": 2.0, + }, + want: 10.0, + }, + { + + name: "Single parameter modified by variable", + expression: `foo["hey"] * bar[1]`, + parameter: map[string]interface{}{ + "foo": map[string]interface{}{"hey": 5.0}, + "bar": []interface{}{7., 2.0}, + }, + want: 10.0, + }, + { + + name: "Multiple multiplications of the same parameter", + expression: "foo * foo * foo", + parameter: map[string]interface{}{ + "foo": 10.0, + }, + want: 1000.0, + }, + { + + name: "Multiple additions of the same parameter", + expression: "foo + foo + foo", + parameter: map[string]interface{}{ + "foo": 10.0, + }, + want: 30.0, + }, + { + name: "NoSpaceOperator", + expression: "true&&name", + parameter: map[string]interface{}{ + "name": true, + }, + want: true, + }, + { + + name: "Parameter name sensitivity", + expression: "foo + FoO + FOO", + parameter: map[string]interface{}{ + "foo": 8.0, + "FoO": 4.0, + "FOO": 2.0, + }, + want: 14.0, + }, + { + + name: "Sign prefix comparison against prefixed variable", + expression: "-1 < -foo", + parameter: map[string]interface{}{"foo": -8.0}, + want: true, + }, + { + + name: "Fixed-point parameter", + expression: "foo > 1", + parameter: map[string]interface{}{"foo": 2}, + want: true, + }, + { + + name: "Modifier after closing clause", + expression: "(2 + 2) + 2 == 6", + want: true, + }, + { + + name: "Comparator after closing clause", + expression: "(2 + 2) >= 4", + want: true, + }, + { + + name: "Two-boolean logical operation (for issue #8)", + expression: "(foo == true) || (bar == true)", + parameter: map[string]interface{}{ + "foo": true, + "bar": false, + }, + want: true, + }, + { + + name: "Two-variable integer logical operation (for issue #8)", + expression: "foo > 10 && bar > 10", + parameter: map[string]interface{}{ + "foo": 1, + "bar": 11, + }, + want: false, + }, + { + + name: "Regex against right-hand parameter", + expression: `"foobar" =~ foo`, + parameter: map[string]interface{}{ + "foo": "obar", + }, + want: true, + }, + { + + name: "Not-regex against right-hand parameter", + expression: `"foobar" !~ foo`, + parameter: map[string]interface{}{ + "foo": "baz", + }, + want: true, + }, + { + + name: "Regex against two parameter", + expression: `foo =~ bar`, + parameter: map[string]interface{}{ + "foo": "foobar", + "bar": "oba", + }, + want: true, + }, + { + + name: "Not-regex against two parameter", + expression: "foo !~ bar", + parameter: map[string]interface{}{ + "foo": "foobar", + "bar": "baz", + }, + want: true, + }, + { + + name: "Pre-compiled regex", + expression: "foo =~ bar", + parameter: map[string]interface{}{ + "foo": "foobar", + "bar": regexp.MustCompile("[fF][oO]+"), + }, + want: true, + }, + { + + name: "Pre-compiled not-regex", + expression: "foo !~ bar", + parameter: map[string]interface{}{ + "foo": "foobar", + "bar": regexp.MustCompile("[fF][oO]+"), + }, + want: false, + }, + { + + name: "Single boolean parameter", + expression: "commission ? 10", + parameter: map[string]interface{}{ + "commission": true}, + want: 10.0, + }, + { + + name: "True comparator with a parameter", + expression: `partner == "amazon" ? 10`, + parameter: map[string]interface{}{ + "partner": "amazon"}, + want: 10.0, + }, + { + + name: "False comparator with a parameter", + expression: `partner == "amazon" ? 10`, + parameter: map[string]interface{}{ + "partner": "ebay"}, + want: nil, + }, + { + + name: "True comparator with multiple parameters", + expression: "theft && period == 24 ? 60", + parameter: map[string]interface{}{ + "theft": true, + "period": 24, + }, + want: 60.0, + }, + { + + name: "False comparator with multiple parameters", + expression: "theft && period == 24 ? 60", + parameter: map[string]interface{}{ + "theft": false, + "period": 24, + }, + want: nil, + }, + { + + name: "String concat with single string parameter", + expression: `foo + "bar"`, + parameter: map[string]interface{}{ + "foo": "baz"}, + want: "bazbar", + }, + { + + name: "String concat with multiple string parameter", + expression: "foo + bar", + parameter: map[string]interface{}{ + "foo": "baz", + "bar": "quux", + }, + want: "bazquux", + }, + { + + name: "String concat with float parameter", + expression: "foo + bar", + parameter: map[string]interface{}{ + "foo": "baz", + "bar": 123.0, + }, + want: "baz123", + }, + { + + name: "Mixed multiple string concat", + expression: `foo + 123 + "bar" + true`, + parameter: map[string]interface{}{"foo": "baz"}, + want: "baz123bartrue", + }, + { + + name: "Integer width spectrum", + expression: "uint8 + uint16 + uint32 + uint64 + int8 + int16 + int32 + int64", + parameter: map[string]interface{}{ + "uint8": uint8(0), + "uint16": uint16(0), + "uint32": uint32(0), + "uint64": uint64(0), + "int8": int8(0), + "int16": int16(0), + "int32": int32(0), + "int64": int64(0), + }, + want: 0.0, + }, + { + + name: "Null coalesce right", + expression: "foo ?? 1.0", + parameter: map[string]interface{}{"foo": nil}, + want: 1.0, + }, + { + + name: "Multiple comparator/logical operators (#30)", + expression: "(foo >= 2887057408 && foo <= 2887122943) || (foo >= 168100864 && foo <= 168118271)", + parameter: map[string]interface{}{"foo": 2887057409}, + want: true, + }, + { + + name: "Multiple comparator/logical operators, opposite order (#30)", + expression: "(foo >= 168100864 && foo <= 168118271) || (foo >= 2887057408 && foo <= 2887122943)", + parameter: map[string]interface{}{"foo": 2887057409}, + want: true, + }, + { + + name: "Multiple comparator/logical operators, small value (#30)", + expression: "(foo >= 2887057408 && foo <= 2887122943) || (foo >= 168100864 && foo <= 168118271)", + parameter: map[string]interface{}{"foo": 168100865}, + want: true, + }, + { + + name: "Multiple comparator/logical operators, small value, opposite order (#30)", + expression: "(foo >= 168100864 && foo <= 168118271) || (foo >= 2887057408 && foo <= 2887122943)", + parameter: map[string]interface{}{"foo": 168100865}, + want: true, + }, + { + + name: "Incomparable array equality comparison", + expression: "arr == arr", + parameter: map[string]interface{}{"arr": []int{0, 0, 0}}, + want: true, + }, + { + + name: "Incomparable array not-equality comparison", + expression: "arr != arr", + parameter: map[string]interface{}{"arr": []int{0, 0, 0}}, + want: false, + }, + { + + name: "Mixed function and parameters", + expression: "sum(1.2, amount) + name", + extension: Function("sum", func(arguments ...interface{}) (interface{}, error) { + sum := 0.0 + for _, v := range arguments { + sum += v.(float64) + } + return sum, nil + }, + ), + parameter: map[string]interface{}{"amount": .8, + "name": "awesome", + }, + + want: "2awesome", + }, + { + + name: "Short-circuit OR", + expression: "true || fail()", + extension: Function("fail", func(arguments ...interface{}) (interface{}, error) { + return nil, fmt.Errorf("Did not short-circuit") + }), + want: true, + }, + { + + name: "Short-circuit AND", + expression: "false && fail()", + extension: Function("fail", func(arguments ...interface{}) (interface{}, error) { + return nil, fmt.Errorf("Did not short-circuit") + }), + want: false, + }, + { + + name: "Short-circuit ternary", + expression: "true ? 1 : fail()", + extension: Function("fail", func(arguments ...interface{}) (interface{}, error) { + return nil, fmt.Errorf("Did not short-circuit") + }), + want: 1.0, + }, + { + + name: "Short-circuit coalesce", + expression: `"foo" ?? fail()`, + extension: Function("fail", func(arguments ...interface{}) (interface{}, error) { + return nil, fmt.Errorf("Did not short-circuit") + }), + want: "foo", + }, + { + + name: "Simple parameter call", + expression: "foo.String", + parameter: map[string]interface{}{"foo": foo}, + want: foo.String, + }, + { + + name: "Simple parameter function call", + expression: "foo.Func()", + parameter: map[string]interface{}{"foo": foo}, + want: "funk", + }, + { + + name: "Simple parameter call from pointer", + expression: "fooptr.String", + parameter: map[string]interface{}{"fooptr": &foo}, + want: foo.String, + }, + { + + name: "Simple parameter function call from pointer", + expression: "fooptr.Func()", + parameter: map[string]interface{}{"fooptr": &foo}, + want: "funk", + }, + { + + name: "Simple parameter call", + expression: `foo.String == "hi"`, + parameter: map[string]interface{}{"foo": foo}, + want: false, + }, + { + + name: "Simple parameter call with modifier", + expression: `foo.String + "hi"`, + parameter: map[string]interface{}{"foo": foo}, + want: foo.String + "hi", + }, + { + + name: "Simple parameter function call, two-arg return", + expression: `foo.Func2()`, + parameter: map[string]interface{}{"foo": foo}, + want: "frink", + }, + { + + name: "Simple parameter function call, one arg", + expression: `foo.FuncArgStr("boop")`, + parameter: map[string]interface{}{"foo": foo}, + want: "boop", + }, + { + + name: "Simple parameter function call, one arg", + expression: `foo.FuncArgStr("boop") + "hi"`, + parameter: map[string]interface{}{"foo": foo}, + want: "boophi", + }, + { + + name: "Nested parameter function call", + expression: `foo.Nested.Dunk("boop")`, + parameter: map[string]interface{}{"foo": foo}, + want: "boopdunk", + }, + { + + name: "Nested parameter call", + expression: "foo.Nested.Funk", + parameter: map[string]interface{}{"foo": foo}, + want: "funkalicious", + }, + { + name: "Nested map call", + expression: `foo.Nested.Map["a"]`, + parameter: map[string]interface{}{"foo": foo}, + want: 1, + }, + { + name: "Nested slice call", + expression: `foo.Nested.Slice[1]`, + parameter: map[string]interface{}{"foo": foo}, + want: 2, + }, + { + + name: "Parameter call with + modifier", + expression: "1 + foo.Int", + parameter: map[string]interface{}{"foo": foo}, + want: 102.0, + }, + { + + name: "Parameter string call with + modifier", + expression: `"woop" + (foo.String)`, + parameter: map[string]interface{}{"foo": foo}, + want: "woopstring!", + }, + { + + name: "Parameter call with && operator", + expression: "true && foo.BoolFalse", + parameter: map[string]interface{}{"foo": foo}, + want: false, + }, + { + name: "Null coalesce nested parameter", + expression: "foo.Nil ?? false", + parameter: map[string]interface{}{"foo": foo}, + want: false, + }, + { + name: "input functions", + expression: "func1() + func2()", + parameter: map[string]interface{}{ + "func1": func() float64 { return 2000 }, + "func2": func() float64 { return 2001 }, + }, + want: 4001.0, + }, + { + name: "input functions", + expression: "func1(date1) + func2(date2)", + parameter: map[string]interface{}{ + "date1": func() interface{} { + y2k, _ := time.Parse("2006", "2000") + return y2k + }(), + "date2": func() interface{} { + y2k1, _ := time.Parse("2006", "2001") + return y2k1 + }(), + }, + extension: NewLanguage( + Function("func1", func(arguments ...interface{}) (interface{}, error) { + return float64(arguments[0].(time.Time).Year()), nil + }), + Function("func2", func(arguments ...interface{}) (interface{}, error) { + return float64(arguments[0].(time.Time).Year()), nil + }), + ), + want: 4001.0, + }, + { + name: "complex64 number as parameter", + expression: "complex64", + parameter: map[string]interface{}{ + "complex64": complex64(0), + "complex128": complex128(0), + }, + want: complex64(0), + }, + { + name: "complex128 number as parameter", + expression: "complex128", + parameter: map[string]interface{}{ + "complex64": complex64(0), + "complex128": complex128(0), + }, + want: complex128(0), + }, + { + name: "coalesce with undefined", + expression: "fooz ?? foo", + parameter: map[string]interface{}{ + "foo": "bar", + }, + want: "bar", + }, + { + name: "map[interface{}]interface{}", + expression: "foo", + parameter: map[interface{}]interface{}{ + "foo": "bar", + }, + want: "bar", + }, + { + name: "method on pointer type", + expression: "foo.PointerFunc()", + parameter: map[string]interface{}{ + "foo": &dummyParameter{}, + }, + want: "point", + }, + { + name: "custom selector", + expression: "hello.world", + parameter: "!", + extension: NewLanguage(Base(), VariableSelector(func(path Evaluables) Evaluable { + return func(c context.Context, v interface{}) (interface{}, error) { + keys, err := path.EvalStrings(c, v) + if err != nil { + return nil, err + } + return fmt.Sprintf("%s%s", strings.Join(keys, " "), v), nil + } + })), + want: "hello world!", + }, + { + name: "map[int]int", + expression: `a[0] + a[2]`, + parameter: map[string]interface{}{ + "a": map[int]int{0: 1, 2: 1}, + }, + want: 2., + }, + { + name: "map[int]string", + expression: `a[0] * a[2]`, + parameter: map[string]interface{}{ + "a": map[int]string{0: "1", 2: "1"}, + }, + want: 1., + }, + { + name: "coalesce typed nil 0", + expression: `ProjectID ?? 0`, + parameter: struct { + ProjectID *uint + }{}, + want: 0., + }, + { + name: "coalesce typed nil 99", + expression: `ProjectID ?? 99`, + parameter: struct { + ProjectID *uint + }{}, + want: 99., + }, + { + name: "operator with typed nil 99", + expression: `ProjectID + 99`, + parameter: struct { + ProjectID *uint + }{}, + want: "99", + }, + { + name: "operator with typed nil if", + expression: `Flag ? 1 : 2`, + parameter: struct { + Flag *uint + }{}, + want: 2., + }, + { + name: "Decimal math doesn't experience rounding error", + expression: "(x * 12.146) - y", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + }, + want: decimal.NewFromFloat(156.825), + equalityFunc: decimalEqualityFunc, + }, + { + name: "Decimal logical operators fractional difference", + expression: "((x * 12.146) - y) > 156.824999999", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + }, + want: true, + }, + { + name: "Decimal logical operators whole number difference", + expression: "((x * 12.146) - y) > 156", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + }, + want: true, + }, + { + name: "Decimal logical operators exact decimal match against GT", + expression: "((x * 12.146) - y) > 156.825", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + }, + want: false, + }, + { + name: "Decimal logical operators exact equality", + expression: "((x * 12.146) - y) == 156.825", + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + }, + want: true, + }, + { + name: "Decimal mixes with string logic with force fail", + expression: `(((x * 12.146) - y) == 156.825) && a == "test" && !b && b`, + extension: decimalArithmetic, + parameter: map[string]interface{}{ + "x": 12.5, + "y": -5, + "a": "test", + "b": false, + }, + want: false, + }, + { + name: "Typed map with function call", + expression: `foo.MapWithFunc.Sum("a")`, + parameter: map[string]interface{}{ + "foo": foo, + }, + want: 3, + }, + { + name: "Types slice with function call", + expression: `foo.SliceWithFunc.Sum("a")`, + parameter: map[string]interface{}{ + "foo": foo, + }, + want: 2, + }, + }, + t, + ) +} diff --git a/gval/gval_parsingFailure_test.go b/gval/gval_parsingFailure_test.go new file mode 100644 index 0000000..3d08488 --- /dev/null +++ b/gval/gval_parsingFailure_test.go @@ -0,0 +1,179 @@ +package gval + +import ( + "regexp/syntax" + "testing" +) + +func TestParsingFailure(t *testing.T) { + testEvaluate( + []evaluationTest{ + { + name: "Invalid equality comparator", + expression: "1 = 1", + wantErr: unexpected(`"="`, "operator"), + }, + { + name: "Invalid equality comparator", + expression: "1 === 1", + wantErr: unexpected(`"="`, "extension"), + }, + { + name: "Too many characters for logical operator", + expression: "true &&& false", + wantErr: unexpected(`"&"`, "extension"), + }, + { + + name: "Too many characters for logical operator", + expression: "true ||| false", + wantErr: unexpected(`"|"`, "extension"), + }, + { + + name: "Premature end to expression, via modifier", + expression: "10 > 5 +", + wantErr: unexpected("EOF", "extensions"), + }, + { + name: "Premature end to expression, via comparator", + expression: "10 + 5 >", + wantErr: unexpected("EOF", "extensions"), + }, + { + name: "Premature end to expression, via logical operator", + expression: "10 > 5 &&", + wantErr: unexpected("EOF", "extensions"), + }, + { + + name: "Premature end to expression, via ternary operator", + expression: "true ?", + wantErr: unexpected("EOF", "extensions"), + }, + { + name: "Hanging REQ", + expression: "`wat` =~", + wantErr: unexpected("EOF", "extensions"), + }, + { + + name: "Invalid operator change to REQ", + expression: " / =~", + wantErr: unexpected(`"/"`, "extensions"), + }, + { + name: "Invalid starting token, comparator", + expression: "> 10", + wantErr: unexpected(`">"`, "extensions"), + }, + { + name: "Invalid starting token, modifier", + expression: "+ 5", + wantErr: unexpected(`"+"`, "extensions"), + }, + { + name: "Invalid starting token, logical operator", + expression: "&& 5 < 10", + wantErr: unexpected(`"&"`, "extensions"), + }, + { + name: "Invalid NUMERIC transition", + expression: "10 10", + wantErr: unexpected(`Int`, "operator"), + }, + { + name: "Invalid STRING transition", + expression: "`foo` `foo`", + wantErr: `String while scanning operator`, // can't use func unexpected because the token was changed from String to RawString in go 1.11 + }, + { + name: "Invalid operator transition", + expression: "10 > < 10", + wantErr: unexpected(`"<"`, "extensions"), + }, + { + + name: "Starting with unbalanced parens", + expression: " ) ( arg2", + wantErr: unexpected(`")"`, "extensions"), + }, + { + + name: "Unclosed bracket", + expression: "[foo bar", + wantErr: unexpected(`EOF`, "extensions"), + }, + { + + name: "Unclosed quote", + expression: "foo == `responseTime", + wantErr: "could not parse string", + }, + { + + name: "Constant regex pattern fail to compile", + expression: "foo =~ `[abc`", + wantErr: string(syntax.ErrMissingBracket), + }, + { + + name: "Constant unmatch regex pattern fail to compile", + expression: "foo !~ `[abc`", + wantErr: string(syntax.ErrMissingBracket), + }, + { + + name: "Unbalanced parentheses", + expression: "10 > (1 + 50", + wantErr: unexpected(`EOF`, "parentheses"), + }, + { + + name: "Multiple radix", + expression: "127.0.0.1", + wantErr: unexpected(`Float`, "operator"), + }, + { + + name: "Hanging accessor", + expression: "foo.Bar.", + wantErr: unexpected(`EOF`, "field"), + }, + { + name: "Incomplete Hex", + expression: "0x", + wantErr: `strconv.ParseFloat: parsing "0x": invalid syntax`, + }, + { + name: "Invalid Hex literal", + expression: "0x > 0", + wantErr: `strconv.ParseFloat: parsing "0x": invalid syntax`, + }, + { + name: "Hex float (Unsupported)", + expression: "0x1.1", + wantErr: `strconv.ParseFloat: parsing "0x1.1": invalid syntax`, + }, + { + name: "Hex invalid letter", + expression: "0x12g1", + wantErr: `strconv.ParseFloat: parsing "0x12": invalid syntax`, + }, + { + name: "Error after camouflage", + expression: "0 + ,", + wantErr: `unexpected "," while scanning extensions`, + }, + }, + t, + ) +} + +func unknownOp(op string) string { + return "unknown operator " + op +} + +func unexpected(token, unit string) string { + return "unexpected " + token + " while scanning " + unit +} diff --git a/gval/gval_test.go b/gval/gval_test.go new file mode 100644 index 0000000..109dc67 --- /dev/null +++ b/gval/gval_test.go @@ -0,0 +1,153 @@ +package gval + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/shopspring/decimal" +) + +type evaluationTest struct { + name string + expression string + extension Language + parameter interface{} + want interface{} + equalityFunc func(x, y interface{}) bool + wantErr string +} + +func testEvaluate(tests []evaluationTest, t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Evaluate(tt.expression, tt.parameter, tt.extension) + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("Evaluate(%s) expected error but got %v", tt.expression, got) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Evaluate(%s) expected error %s but got error %v", tt.expression, tt.wantErr, err) + } + return + } + if err != nil { + t.Errorf("Evaluate() error = %v", err) + return + } + if ef := tt.equalityFunc; ef != nil { + if !ef(got, tt.want) { + t.Errorf("Evaluate(%s) = %v, want %v", tt.expression, got, tt.want) + } + } else if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Evaluate(%s) = %v, want %v", tt.expression, got, tt.want) + } + }) + } +} + +// dummyParameter used to test "parameter calls". +type dummyParameter struct { + String string + Int int + BoolFalse bool + Nil interface{} + Nested dummyNestedParameter + MapWithFunc dummyMapWithFunc + SliceWithFunc dummySliceWithFunc +} + +func (d dummyParameter) Func() string { + return "funk" +} + +func (d dummyParameter) Func2() (string, error) { + return "frink", nil +} + +func (d *dummyParameter) PointerFunc() (string, error) { + return "point", nil +} + +func (d dummyParameter) FuncErr() (string, error) { + return "", fmt.Errorf("fumps") +} + +func (d dummyParameter) FuncArgStr(arg1 string) string { + return arg1 +} + +func (d dummyParameter) AlwaysFail() (interface{}, error) { + return nil, fmt.Errorf("function should always fail") +} + +type dummyNestedParameter struct { + Funk string + Map map[string]int + Slice []int +} + +func (d dummyNestedParameter) Dunk(arg1 string) string { + return arg1 + "dunk" +} + +var foo = dummyParameter{ + String: "string!", + Int: 101, + BoolFalse: false, + Nil: nil, + Nested: dummyNestedParameter{ + Funk: "funkalicious", + Map: map[string]int{"a": 1, "b": 2, "c": 3}, + Slice: []int{1, 2, 3}, + }, + MapWithFunc: dummyMapWithFunc{"a": {1, 2}, "b": {3, 4}}, + SliceWithFunc: dummySliceWithFunc{"a", "b", "c", "a"}, +} + +var fooFailureParameters = map[string]interface{}{ + "foo": foo, + "fooptr": &foo, +} + +var decimalEqualityFunc = func(x, y interface{}) bool { + v1, ok1 := x.(decimal.Decimal) + v2, ok2 := y.(decimal.Decimal) + + if !ok1 || !ok2 { + return false + } + + return v1.Equal(v2) +} + +type dummyMapWithFunc map[string][]int + +func (m dummyMapWithFunc) Sum(key string) int { + values, ok := m[key] + if !ok { + return -1 + } + + sum := 0 + for _, v := range values { + sum += v + } + + return sum +} + +type dummySliceWithFunc []string + +func (m dummySliceWithFunc) Sum(key string) int { + sum := 0 + for _, v := range m { + if v == key { + sum += 1 + } + } + + return sum +} diff --git a/gval/language.go b/gval/language.go new file mode 100644 index 0000000..a17525c --- /dev/null +++ b/gval/language.go @@ -0,0 +1,281 @@ +package gval + +import ( + "context" + "fmt" + "text/scanner" + "unicode" + + "github.com/shopspring/decimal" +) + +// Language is an expression language +type Language struct { + prefixes map[interface{}]extension + operators map[string]operator + operatorSymbols map[rune]struct{} + init extension + def extension + selector func(Evaluables) Evaluable +} + +// NewLanguage returns the union of given Languages as new Language. +func NewLanguage(bases ...Language) Language { + l := newLanguage() + for _, base := range bases { + for i, e := range base.prefixes { + l.prefixes[i] = e + } + for i, e := range base.operators { + l.operators[i] = e.merge(l.operators[i]) + l.operators[i].initiate(i) + } + for i := range base.operatorSymbols { + l.operatorSymbols[i] = struct{}{} + } + if base.init != nil { + l.init = base.init + } + if base.def != nil { + l.def = base.def + } + if base.selector != nil { + l.selector = base.selector + } + } + return l +} + +func newLanguage() Language { + return Language{ + prefixes: map[interface{}]extension{}, + operators: map[string]operator{}, + operatorSymbols: map[rune]struct{}{}, + } +} + +// NewEvaluable returns an Evaluable for given expression in the specified language +func (l Language) NewEvaluable(expression string) (Evaluable, error) { + return l.NewEvaluableWithContext(context.Background(), expression) +} + +// NewEvaluableWithContext returns an Evaluable for given expression in the specified language using context +func (l Language) NewEvaluableWithContext(c context.Context, expression string) (Evaluable, error) { + p := newParser(expression, l) + + eval, err := p.parse(c) + if err == nil && p.isCamouflaged() && p.lastScan != scanner.EOF { + err = p.camouflage + } + if err != nil { + pos := p.scanner.Pos() + return nil, fmt.Errorf("parsing error: %s - %d:%d %w", p.scanner.Position, pos.Line, pos.Column, err) + } + + return eval, nil +} + +// Evaluate given parameter with given expression +func (l Language) Evaluate(expression string, parameter interface{}) (interface{}, error) { + return l.EvaluateWithContext(context.Background(), expression, parameter) +} + +// Evaluate given parameter with given expression using context +func (l Language) EvaluateWithContext(c context.Context, expression string, parameter interface{}) (interface{}, error) { + eval, err := l.NewEvaluableWithContext(c, expression) + if err != nil { + return nil, err + } + v, err := eval(c, parameter) + if err != nil { + return nil, fmt.Errorf("can not evaluate %s: %w", expression, err) + } + return v, nil +} + +// Function returns a Language with given function. +// Function has no conversion for input types. +// +// If the function returns an error it must be the last return parameter. +// +// If the function has (without the error) more then one return parameter, +// it returns them as []interface{}. +func Function(name string, function interface{}) Language { + l := newLanguage() + l.prefixes[name] = func(c context.Context, p *Parser) (eval Evaluable, err error) { + args := []Evaluable{} + scan := p.Scan() + switch scan { + case '(': + args, err = p.parseArguments(c) + if err != nil { + return nil, err + } + default: + p.Camouflage("function call", '(') + } + return p.callFunc(toFunc(function), args...), nil + } + return l +} + +// Constant returns a Language with given constant +func Constant(name string, value interface{}) Language { + l := newLanguage() + l.prefixes[l.makePrefixKey(name)] = func(c context.Context, p *Parser) (eval Evaluable, err error) { + return p.Const(value), nil + } + return l +} + +// PrefixExtension extends a Language +func PrefixExtension(r rune, ext func(context.Context, *Parser) (Evaluable, error)) Language { + l := newLanguage() + l.prefixes[r] = ext + return l +} + +// Init is a language that does no parsing, but invokes the given function when +// parsing starts. It is incumbent upon the function to call ParseExpression to +// continue parsing. +// +// This function can be used to customize the parser settings, such as +// whitespace or ident behavior. +func Init(ext func(context.Context, *Parser) (Evaluable, error)) Language { + l := newLanguage() + l.init = ext + return l +} + +// DefaultExtension is a language that runs the given function if no other +// prefix matches. +func DefaultExtension(ext func(context.Context, *Parser) (Evaluable, error)) Language { + l := newLanguage() + l.def = ext + return l +} + +// PrefixMetaPrefix chooses a Prefix to be executed +func PrefixMetaPrefix(r rune, ext func(context.Context, *Parser) (call string, alternative func() (Evaluable, error), err error)) Language { + l := newLanguage() + l.prefixes[r] = func(c context.Context, p *Parser) (Evaluable, error) { + call, alternative, err := ext(c, p) + if err != nil { + return nil, err + } + if prefix, ok := p.prefixes[l.makePrefixKey(call)]; ok { + return prefix(c, p) + } + return alternative() + } + return l +} + +// PrefixOperator returns a Language with given prefix +func PrefixOperator(name string, e Evaluable) Language { + l := newLanguage() + l.prefixes[l.makePrefixKey(name)] = func(c context.Context, p *Parser) (Evaluable, error) { + eval, err := p.ParseNextExpression(c) + if err != nil { + return nil, err + } + prefix := func(c context.Context, v interface{}) (interface{}, error) { + a, err := eval(c, v) + if err != nil { + return nil, err + } + return e(c, a) + } + if eval.IsConst() { + v, err := prefix(c, nil) + if err != nil { + return nil, err + } + prefix = p.Const(v) + } + return prefix, nil + } + return l +} + +// PostfixOperator extends a Language. +func PostfixOperator(name string, ext func(context.Context, *Parser, Evaluable) (Evaluable, error)) Language { + l := newLanguage() + l.operators[l.makeInfixKey(name)] = postfix{ + f: func(c context.Context, p *Parser, eval Evaluable, pre operatorPrecedence) (Evaluable, error) { + return ext(c, p, eval) + }, + } + return l +} + +// InfixOperator for two arbitrary values. +func InfixOperator(name string, f func(a, b interface{}) (interface{}, error)) Language { + return newLanguageOperator(name, &infix{arbitrary: f}) +} + +// InfixShortCircuit operator is called after the left operand is evaluated. +func InfixShortCircuit(name string, f func(a interface{}) (interface{}, bool)) Language { + return newLanguageOperator(name, &infix{shortCircuit: f}) +} + +// InfixTextOperator for two text values. +func InfixTextOperator(name string, f func(a, b string) (interface{}, error)) Language { + return newLanguageOperator(name, &infix{text: f}) +} + +// InfixNumberOperator for two number values. +func InfixNumberOperator(name string, f func(a, b uint) (interface{}, error)) Language { + return newLanguageOperator(name, &infix{number: f}) +} + +// InfixDecimalOperator for two decimal values. +func InfixDecimalOperator(name string, f func(a, b decimal.Decimal) (interface{}, error)) Language { + return newLanguageOperator(name, &infix{decimal: f}) +} + +// InfixBoolOperator for two bool values. +func InfixBoolOperator(name string, f func(a, b bool) (interface{}, error)) Language { + return newLanguageOperator(name, &infix{boolean: f}) +} + +// Precedence of operator. The Operator with higher operatorPrecedence is evaluated first. +func Precedence(name string, operatorPrecendence uint8) Language { + return newLanguageOperator(name, operatorPrecedence(operatorPrecendence)) +} + +// InfixEvalOperator operates on the raw operands. +// Therefore it cannot be combined with operators for other operand types. +func InfixEvalOperator(name string, f func(a, b Evaluable) (Evaluable, error)) Language { + return newLanguageOperator(name, directInfix{infixBuilder: f}) +} + +func newLanguageOperator(name string, op operator) Language { + op.initiate(name) + l := newLanguage() + l.operators[l.makeInfixKey(name)] = op + return l +} + +func (l *Language) makePrefixKey(key string) interface{} { + runes := []rune(key) + if len(runes) == 1 && !unicode.IsLetter(runes[0]) { + return runes[0] + } + return key +} + +func (l *Language) makeInfixKey(key string) string { + for _, r := range key { + l.operatorSymbols[r] = struct{}{} + } + return key +} + +// VariableSelector returns a Language which uses given variable selector. +// It must be combined with a Language that uses the vatiable selector. E.g. gval.Base(). +func VariableSelector(selector func(path Evaluables) Evaluable) Language { + l := newLanguage() + l.selector = selector + return l +} diff --git a/gval/operator.go b/gval/operator.go new file mode 100644 index 0000000..da0285c --- /dev/null +++ b/gval/operator.go @@ -0,0 +1,403 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/shopspring/decimal" +) + +type stage struct { + Evaluable + infixBuilder + operatorPrecedence +} + +type stageStack []stage //operatorPrecedence in stacktStage is continuously, monotone ascending + +func (s *stageStack) push(b stage) error { + for len(*s) > 0 && s.peek().operatorPrecedence >= b.operatorPrecedence { + a := s.pop() + eval, err := a.infixBuilder(a.Evaluable, b.Evaluable) + if err != nil { + return err + } + if a.IsConst() && b.IsConst() { + v, err := eval(nil, nil) + if err != nil { + return err + } + b.Evaluable = constant(v) + continue + } + b.Evaluable = eval + } + *s = append(*s, b) + return nil +} + +func (s *stageStack) peek() stage { + return (*s)[len(*s)-1] +} + +func (s *stageStack) pop() stage { + a := s.peek() + (*s) = (*s)[:len(*s)-1] + return a +} + +type infixBuilder func(a, b Evaluable) (Evaluable, error) + +func (l Language) isSymbolOperation(r rune) bool { + _, in := l.operatorSymbols[r] + return in +} + +func (l Language) isOperatorPrefix(op string) bool { + for k := range l.operators { + if strings.HasPrefix(k, op) { + return true + } + } + return false +} + +func (op *infix) initiate(name string) { + f := func(a, b interface{}) (interface{}, error) { + return nil, fmt.Errorf("invalid operation (%T) %s (%T)", a, name, b) + } + if op.arbitrary != nil { + f = op.arbitrary + } + for _, typeConvertion := range []bool{true, false} { + if op.text != nil && (!typeConvertion || op.arbitrary == nil) { + f = getStringOpFunc(op.text, f, typeConvertion) + } + if op.boolean != nil { + f = getBoolOpFunc(op.boolean, f, typeConvertion) + } + if op.number != nil { + f = getUintOpFunc(op.number, f, typeConvertion) + } + if op.decimal != nil { + f = getDecimalOpFunc(op.decimal, f, typeConvertion) + } + } + if op.shortCircuit == nil { + op.builder = func(a, b Evaluable) (Evaluable, error) { + return func(c context.Context, x interface{}) (interface{}, error) { + a, err := a(c, x) + if err != nil { + return nil, err + } + b, err := b(c, x) + if err != nil { + return nil, err + } + return f(a, b) + }, nil + } + return + } + shortF := op.shortCircuit + op.builder = func(a, b Evaluable) (Evaluable, error) { + return func(c context.Context, x interface{}) (interface{}, error) { + a, err := a(c, x) + if err != nil { + return nil, err + } + if r, ok := shortF(a); ok { + return r, nil + } + b, err := b(c, x) + if err != nil { + return nil, err + } + return f(a, b) + }, nil + } +} + +type opFunc func(a, b interface{}) (interface{}, error) + +func getStringOpFunc(s func(a, b string) (interface{}, error), f opFunc, typeConversion bool) opFunc { + if typeConversion { + return func(a, b interface{}) (interface{}, error) { + if a != nil && b != nil { + return s(fmt.Sprintf("%v", a), fmt.Sprintf("%v", b)) + } + return f(a, b) + } + } + return func(a, b interface{}) (interface{}, error) { + s1, k := a.(string) + s2, l := b.(string) + if k && l { + return s(s1, s2) + } + return f(a, b) + } +} +func convertToBool(o interface{}) (bool, bool) { + if b, ok := o.(bool); ok { + return b, true + } + v := reflect.ValueOf(o) + + if v.Kind() == reflect.Func { + if vt := v.Type(); vt.NumIn() == 0 && vt.NumOut() == 1 { + retType := vt.Out(0) + + if retType.Kind() == reflect.Bool { + funcResults := v.Call([]reflect.Value{}) + v = funcResults[0] + o = v.Interface() + } + } + } + + for o != nil && v.Kind() == reflect.Ptr { + v = v.Elem() + if !v.IsValid() { + return false, false + } + o = v.Interface() + } + + if o == false || o == nil || o == "false" || o == "FALSE" { + return false, true + } + if o == true || o == "true" || o == "TRUE" { + return true, true + } + if f, ok := convertToUint(o); ok { + return f != 0., true + } + return false, false +} +func getBoolOpFunc(o func(a, b bool) (interface{}, error), f opFunc, typeConversion bool) opFunc { + if typeConversion { + return func(a, b interface{}) (interface{}, error) { + x, k := convertToBool(a) + y, l := convertToBool(b) + if k && l { + return o(x, y) + } + return f(a, b) + } + } + return func(a, b interface{}) (interface{}, error) { + x, k := a.(bool) + y, l := b.(bool) + if k && l { + return o(x, y) + } + return f(a, b) + } +} +func convertToUint(o interface{}) (uint, bool) { + if i, ok := o.(uint); ok { + return i, true + } + + v := reflect.ValueOf(o) + for o != nil && v.Kind() == reflect.Ptr { + v = v.Elem() + if !v.IsValid() { + return 0, false + } + o = v.Interface() + } + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint(v.Int()), true + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint(v.Uint()), true + } + if s, ok := o.(string); ok { + u, err := strconv.ParseUint(s, 0, 32) + if err == nil { + return uint(u), true + } + } + return 0, false +} +func getUintOpFunc(o func(a, b uint) (interface{}, error), f opFunc, typeConversion bool) opFunc { + if typeConversion { + return func(a, b interface{}) (interface{}, error) { + x, k := convertToUint(a) + y, l := convertToUint(b) + if k && l { + return o(x, y) + } + + return f(a, b) + } + } + return func(a, b interface{}) (interface{}, error) { + x, k := a.(uint) + y, l := b.(uint) + if k && l { + return o(x, y) + } + return f(a, b) + } +} +func convertToDecimal(o interface{}) (decimal.Decimal, bool) { + if i, ok := o.(decimal.Decimal); ok { + return i, true + } + if i, ok := o.(float64); ok { + return decimal.NewFromFloat(i), true + } + v := reflect.ValueOf(o) + for o != nil && v.Kind() == reflect.Ptr { + v = v.Elem() + if !v.IsValid() { + return decimal.Zero, false + } + o = v.Interface() + } + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decimal.NewFromInt(v.Int()), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return decimal.NewFromFloat(float64(v.Uint())), true + case reflect.Float32, reflect.Float64: + return decimal.NewFromFloat(v.Float()), true + } + if s, ok := o.(string); ok { + f, err := strconv.ParseFloat(s, 64) + if err == nil { + return decimal.NewFromFloat(f), true + } + } + return decimal.Zero, false +} +func getDecimalOpFunc(o func(a, b decimal.Decimal) (interface{}, error), f opFunc, typeConversion bool) opFunc { + if typeConversion { + return func(a, b interface{}) (interface{}, error) { + x, k := convertToDecimal(a) + y, l := convertToDecimal(b) + if k && l { + return o(x, y) + } + + return f(a, b) + } + } + return func(a, b interface{}) (interface{}, error) { + x, k := a.(decimal.Decimal) + y, l := b.(decimal.Decimal) + if k && l { + return o(x, y) + } + + return f(a, b) + } +} + +type operator interface { + merge(operator) operator + precedence() operatorPrecedence + initiate(name string) +} + +type operatorPrecedence uint8 + +func (pre operatorPrecedence) merge(op operator) operator { + if op, ok := op.(operatorPrecedence); ok { + if op > pre { + return op + } + return pre + } + if op == nil { + return pre + } + return op.merge(pre) +} + +func (pre operatorPrecedence) precedence() operatorPrecedence { + return pre +} + +func (pre operatorPrecedence) initiate(name string) {} + +type infix struct { + operatorPrecedence + number func(a, b uint) (interface{}, error) + decimal func(a, b decimal.Decimal) (interface{}, error) + boolean func(a, b bool) (interface{}, error) + text func(a, b string) (interface{}, error) + arbitrary func(a, b interface{}) (interface{}, error) + shortCircuit func(a interface{}) (interface{}, bool) + builder infixBuilder +} + +func (op infix) merge(op2 operator) operator { + switch op2 := op2.(type) { + case *infix: + if op.number == nil { + op.number = op2.number + } + if op.decimal == nil { + op.decimal = op2.decimal + } + if op.boolean == nil { + op.boolean = op2.boolean + } + if op.text == nil { + op.text = op2.text + } + if op.arbitrary == nil { + op.arbitrary = op2.arbitrary + } + if op.shortCircuit == nil { + op.shortCircuit = op2.shortCircuit + } + } + if op2 != nil && op2.precedence() > op.operatorPrecedence { + op.operatorPrecedence = op2.precedence() + } + return &op +} + +type directInfix struct { + operatorPrecedence + infixBuilder +} + +func (op directInfix) merge(op2 operator) operator { + switch op2 := op2.(type) { + case operatorPrecedence: + op.operatorPrecedence = op2 + } + if op2 != nil && op2.precedence() > op.operatorPrecedence { + op.operatorPrecedence = op2.precedence() + } + return op +} + +type extension func(context.Context, *Parser) (Evaluable, error) + +type postfix struct { + operatorPrecedence + f func(context.Context, *Parser, Evaluable, operatorPrecedence) (Evaluable, error) +} + +func (op postfix) merge(op2 operator) operator { + switch op2 := op2.(type) { + case postfix: + if op2.f != nil { + op.f = op2.f + } + } + if op2 != nil && op2.precedence() > op.operatorPrecedence { + op.operatorPrecedence = op2.precedence() + } + return op +} diff --git a/gval/operator_test.go b/gval/operator_test.go new file mode 100644 index 0000000..25d1be1 --- /dev/null +++ b/gval/operator_test.go @@ -0,0 +1,166 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "testing" +) + +func Test_Infix(t *testing.T) { + type subTest struct { + name string + a interface{} + b interface{} + wantRet interface{} + } + tests := []struct { + name string + infix + subTests []subTest + }{ + { + "number operator", + infix{ + number: func(a, b float64) (interface{}, error) { return a * b, nil }, + }, + []subTest{ + {"float64 arguments", 7., 3., 21.}, + {"int arguments", 7, 3, 21.}, + {"string arguments", "7", "3.", 21.}, + }, + }, + { + "number and string operator", + infix{ + number: func(a, b float64) (interface{}, error) { return a + b, nil }, + text: func(a, b string) (interface{}, error) { return fmt.Sprintf("%v%v", a, b), nil }, + }, + + []subTest{ + {"float64 arguments", 7., 3., 10.}, + {"int arguments", 7, 3, 10.}, + {"number string arguments", "7", "3.", "73."}, + {"string arguments", "hello ", "world", "hello world"}, + }, + }, + { + "bool operator", + infix{ + shortCircuit: func(a interface{}) (interface{}, bool) { return false, a == false }, + boolean: func(a, b bool) (interface{}, error) { return a && b, nil }, + }, + + []subTest{ + {"bool arguments", false, true, false}, + {"number arguments", 0, true, false}, + {"lower string arguments", "false", "true", false}, + {"upper string arguments", "TRUE", "FALSE", false}, + {"shortCircuit", false, "not a boolean", false}, + }, + }, + { + "bool, number, text and interface operator", + infix{ + number: func(a, b float64) (interface{}, error) { return a == b, nil }, + boolean: func(a, b bool) (interface{}, error) { return a == b, nil }, + text: func(a, b string) (interface{}, error) { return a == b, nil }, + arbitrary: func(a, b interface{}) (interface{}, error) { return a == b, nil }, + }, + + []subTest{ + {"number string and int arguments", "7", 7, true}, + {"bool string and bool arguments", "true", true, true}, + {"string arguments", "hello", "hello", true}, + {"upper string arguments", "TRUE", "FALSE", false}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.infix.initiate("<" + tt.name + ">") + builder := tt.infix.builder + for _, tt := range tt.subTests { + t.Run(tt.name, func(t *testing.T) { + eval, err := builder(constant(tt.a), constant(tt.b)) + if err != nil { + t.Fatal(err) + } + + got, err := eval(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, tt.wantRet) { + t.Fatalf("binaryOperator() eval() = %v, want %v", got, tt.wantRet) + } + }) + } + }) + } +} + +func Test_stageStack_push(t *testing.T) { + p := (*Parser)(nil) + tests := []struct { + name string + pres []operatorPrecedence + expect string + }{ + { + "flat", + []operatorPrecedence{1, 1, 1, 1}, + "((((AB)C)D)E)", + }, + { + "asc", + []operatorPrecedence{1, 2, 3, 4}, + "(A(B(C(DE))))", + }, + { + "desc", + []operatorPrecedence{4, 3, 2, 1}, + "((((AB)C)D)E)", + }, + { + "mixed", + []operatorPrecedence{1, 2, 1, 1}, + "(((A(BC))D)E)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + X := int('A') + + op := func(a, b Evaluable) (Evaluable, error) { + return func(c context.Context, o interface{}) (interface{}, error) { + aa, _ := a.EvalString(c, nil) + bb, _ := b.EvalString(c, nil) + s := "(" + aa + bb + ")" + return s, nil + }, nil + } + stack := stageStack{} + for _, pre := range tt.pres { + if err := stack.push(stage{p.Const(string(rune(X))), op, pre}); err != nil { + t.Fatal(err) + } + X++ + } + + if err := stack.push(stage{p.Const(string(rune(X))), nil, 0}); err != nil { + t.Fatal(err) + } + + if len(stack) != 1 { + t.Fatalf("stack must hold exactly one element") + } + + got, _ := stack[0].EvalString(context.Background(), nil) + if got != tt.expect { + t.Fatalf("got %s but expected %s", got, tt.expect) + } + }) + } +} diff --git a/gval/parse.go b/gval/parse.go new file mode 100644 index 0000000..51bd09f --- /dev/null +++ b/gval/parse.go @@ -0,0 +1,347 @@ +package gval + +import ( + "context" + "fmt" + "reflect" + "strconv" + "text/scanner" + + "github.com/shopspring/decimal" +) + +// ParseExpression scans an expression into an Evaluable. +func (p *Parser) ParseExpression(c context.Context) (eval Evaluable, err error) { + stack := stageStack{} + for { + eval, err = p.ParseNextExpression(c) + if err != nil { + return nil, err + } + + if stage, err := p.parseOperator(c, &stack, eval); err != nil { + return nil, err + } else if err = stack.push(stage); err != nil { + return nil, err + } + + if stack.peek().infixBuilder == nil { + return stack.pop().Evaluable, nil + } + } +} + +// ParseNextExpression scans the expression ignoring following operators +func (p *Parser) ParseNextExpression(c context.Context) (eval Evaluable, err error) { + scan := p.Scan() + ex, ok := p.prefixes[scan] + if !ok { + if scan != scanner.EOF && p.def != nil { + return p.def(c, p) + } + return nil, p.Expected("extensions") + } + return ex(c, p) +} + +// ParseSublanguage sets the next language for this parser to parse and calls +// its initialization function, usually ParseExpression. +func (p *Parser) ParseSublanguage(c context.Context, l Language) (Evaluable, error) { + if p.isCamouflaged() { + panic("can not ParseSublanguage() on camouflaged Parser") + } + curLang := p.Language + curWhitespace := p.scanner.Whitespace + curMode := p.scanner.Mode + curIsIdentRune := p.scanner.IsIdentRune + + p.Language = l + p.resetScannerProperties() + + defer func() { + p.Language = curLang + p.scanner.Whitespace = curWhitespace + p.scanner.Mode = curMode + p.scanner.IsIdentRune = curIsIdentRune + }() + + return p.parse(c) +} + +func (p *Parser) parse(c context.Context) (Evaluable, error) { + if p.init != nil { + return p.init(c, p) + } + + return p.ParseExpression(c) +} + +func parseString(c context.Context, p *Parser) (Evaluable, error) { + s, err := strconv.Unquote(p.TokenText()) + if err != nil { + return nil, fmt.Errorf("could not parse string: %w", err) + } + return p.Const(s), nil +} + +func parseNumber(c context.Context, p *Parser) (Evaluable, error) { + n, err := strconv.ParseUint(p.TokenText(), 0, 32) + if err != nil { + return nil, err + } + return p.Const(n), nil +} + +func parseDecimal(c context.Context, p *Parser) (Evaluable, error) { + n, err := strconv.ParseFloat(p.TokenText(), 64) + if err != nil { + return nil, err + } + return p.Const(decimal.NewFromFloat(n)), nil +} + +func parseParentheses(c context.Context, p *Parser) (Evaluable, error) { + eval, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + switch p.Scan() { + case ')': + return eval, nil + default: + return nil, p.Expected("parentheses", ')') + } +} + +func (p *Parser) parseOperator(c context.Context, stack *stageStack, eval Evaluable) (st stage, err error) { + for { + scan := p.Scan() + op := p.TokenText() + mustOp := false + if p.isSymbolOperation(scan) { + scan = p.Peek() + for p.isSymbolOperation(scan) && p.isOperatorPrefix(op+string(scan)) { + mustOp = true + op += string(scan) + p.Next() + scan = p.Peek() + } + } else if scan != scanner.Ident { + p.Camouflage("operator") + return stage{Evaluable: eval}, nil + } + switch operator := p.operators[op].(type) { + case *infix: + return stage{ + Evaluable: eval, + infixBuilder: operator.builder, + operatorPrecedence: operator.operatorPrecedence, + }, nil + case directInfix: + return stage{ + Evaluable: eval, + infixBuilder: operator.infixBuilder, + operatorPrecedence: operator.operatorPrecedence, + }, nil + case postfix: + if err = stack.push(stage{ + operatorPrecedence: operator.operatorPrecedence, + Evaluable: eval, + }); err != nil { + return stage{}, err + } + eval, err = operator.f(c, p, stack.pop().Evaluable, operator.operatorPrecedence) + if err != nil { + return + } + continue + } + + if !mustOp { + p.Camouflage("operator") + return stage{Evaluable: eval}, nil + } + return stage{}, fmt.Errorf("unknown operator %s", op) + } +} + +func parseIdent(c context.Context, p *Parser) (call string, alternative func() (Evaluable, error), err error) { + token := p.TokenText() + return token, + func() (Evaluable, error) { + fullname := token + + keys := []Evaluable{p.Const(token)} + for { + scan := p.Scan() + switch scan { + case '.': + scan = p.Scan() + switch scan { + case scanner.Ident: + token = p.TokenText() + keys = append(keys, p.Const(token)) + default: + return nil, p.Expected("field", scanner.Ident) + } + case '(': + args, err := p.parseArguments(c) + if err != nil { + return nil, err + } + return p.callEvaluable(fullname, p.Var(keys...), args...), nil + case '[': + key, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + switch p.Scan() { + case ']': + keys = append(keys, key) + default: + return nil, p.Expected("array key", ']') + } + default: + p.Camouflage("variable", '.', '(', '[') + return p.Var(keys...), nil + } + } + }, nil + +} + +func (p *Parser) parseArguments(c context.Context) (args []Evaluable, err error) { + if p.Scan() == ')' { + return + } + p.Camouflage("scan arguments", ')') + for { + arg, err := p.ParseExpression(c) + args = append(args, arg) + if err != nil { + return nil, err + } + switch p.Scan() { + case ')': + return args, nil + case ',': + default: + return nil, p.Expected("arguments", ')', ',') + } + } +} + +func inArray(a, b interface{}) (interface{}, error) { + col, ok := b.([]interface{}) + if !ok { + return nil, fmt.Errorf("expected type []interface{} for in operator but got %T", b) + } + for _, value := range col { + if reflect.DeepEqual(a, value) { + return true, nil + } + } + return false, nil +} + +func parseIf(c context.Context, p *Parser, e Evaluable) (Evaluable, error) { + a, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + b := p.Const(nil) + switch p.Scan() { + case ':': + b, err = p.ParseExpression(c) + if err != nil { + return nil, err + } + case scanner.EOF: + default: + return nil, p.Expected("<> ? <> : <>", ':', scanner.EOF) + } + return func(c context.Context, v interface{}) (interface{}, error) { + x, err := e(c, v) + if err != nil { + return nil, err + } + if valX := reflect.ValueOf(x); x == nil || valX.IsZero() { + return b(c, v) + } + return a(c, v) + }, nil +} + +func parseJSONArray(c context.Context, p *Parser) (Evaluable, error) { + evals := []Evaluable{} + for { + switch p.Scan() { + default: + p.Camouflage("array", ',', ']') + eval, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + evals = append(evals, eval) + case ',': + case ']': + return func(c context.Context, v interface{}) (interface{}, error) { + vs := make([]interface{}, len(evals)) + for i, e := range evals { + eval, err := e(c, v) + if err != nil { + return nil, err + } + vs[i] = eval + } + + return vs, nil + }, nil + } + } +} + +func parseJSONObject(c context.Context, p *Parser) (Evaluable, error) { + type kv struct { + key Evaluable + value Evaluable + } + evals := []kv{} + for { + switch p.Scan() { + default: + p.Camouflage("object", ',', '}') + key, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + if p.Scan() != ':' { + if err != nil { + return nil, p.Expected("object", ':') + } + } + value, err := p.ParseExpression(c) + if err != nil { + return nil, err + } + evals = append(evals, kv{key, value}) + case ',': + case '}': + return func(c context.Context, v interface{}) (interface{}, error) { + vs := map[string]interface{}{} + for _, e := range evals { + value, err := e.value(c, v) + if err != nil { + return nil, err + } + key, err := e.key.EvalString(c, v) + if err != nil { + return nil, err + } + vs[key] = value + } + return vs, nil + }, nil + } + } +} diff --git a/gval/parser.go b/gval/parser.go new file mode 100644 index 0000000..19cbabe --- /dev/null +++ b/gval/parser.go @@ -0,0 +1,147 @@ +package gval + +import ( + "bytes" + "fmt" + "strings" + "text/scanner" + "unicode" +) + +// Parser parses expressions in a Language into an Evaluable +type Parser struct { + scanner scanner.Scanner + Language + lastScan rune + camouflage error +} + +func newParser(expression string, l Language) *Parser { + sc := scanner.Scanner{} + sc.Init(strings.NewReader(expression)) + sc.Error = func(*scanner.Scanner, string) {} + sc.Filename = expression + "\t" + p := &Parser{scanner: sc, Language: l} + p.resetScannerProperties() + return p +} + +func (p *Parser) resetScannerProperties() { + p.scanner.Whitespace = scanner.GoWhitespace + p.scanner.Mode = scanner.GoTokens + p.scanner.IsIdentRune = func(r rune, pos int) bool { + return unicode.IsLetter(r) || r == '_' || (pos > 0 && unicode.IsDigit(r)) + } +} + +// SetWhitespace sets the behavior of the whitespace matcher. The given +// characters must be less than or equal to 0x20 (' '). +func (p *Parser) SetWhitespace(chars ...rune) { + var mask uint64 + for _, char := range chars { + mask |= 1 << uint(char) + } + + p.scanner.Whitespace = mask +} + +// SetMode sets the tokens that the underlying scanner will match. +func (p *Parser) SetMode(mode uint) { + p.scanner.Mode = mode +} + +// SetIsIdentRuneFunc sets the function that matches ident characters in the +// underlying scanner. +func (p *Parser) SetIsIdentRuneFunc(fn func(ch rune, i int) bool) { + p.scanner.IsIdentRune = fn +} + +// Scan reads the next token or Unicode character from source and returns it. +// It only recognizes tokens t for which the respective Mode bit (1<<-t) is set. +// It returns scanner.EOF at the end of the source. +func (p *Parser) Scan() rune { + if p.isCamouflaged() { + p.camouflage = nil + return p.lastScan + } + p.camouflage = nil + p.lastScan = p.scanner.Scan() + return p.lastScan +} + +func (p *Parser) isCamouflaged() bool { + return p.camouflage != nil && p.camouflage != errCamouflageAfterNext +} + +// Camouflage rewind the last Scan(). The Parser holds the camouflage error until +// the next Scan() +// Do not call Rewind() on a camouflaged Parser +func (p *Parser) Camouflage(unit string, expected ...rune) { + if p.isCamouflaged() { + panic(fmt.Errorf("can only Camouflage() after Scan(): %w", p.camouflage)) + } + p.camouflage = p.Expected(unit, expected...) +} + +// Peek returns the next Unicode character in the source without advancing +// the scanner. It returns EOF if the scanner's position is at the last +// character of the source. +// Do not call Peek() on a camouflaged Parser +func (p *Parser) Peek() rune { + if p.isCamouflaged() { + panic("can not Peek() on camouflaged Parser") + } + return p.scanner.Peek() +} + +var errCamouflageAfterNext = fmt.Errorf("Camouflage() after Next()") + +// Next reads and returns the next Unicode character. +// It returns EOF at the end of the source. +// Do not call Next() on a camouflaged Parser +func (p *Parser) Next() rune { + if p.isCamouflaged() { + panic("can not Next() on camouflaged Parser") + } + p.camouflage = errCamouflageAfterNext + return p.scanner.Next() +} + +// TokenText returns the string corresponding to the most recently scanned token. +// Valid after calling Scan(). +func (p *Parser) TokenText() string { + return p.scanner.TokenText() +} + +// Expected returns an error signaling an unexpected Scan() result +func (p *Parser) Expected(unit string, expected ...rune) error { + return unexpectedRune{unit, expected, p.lastScan} +} + +type unexpectedRune struct { + unit string + expected []rune + got rune +} + +func (err unexpectedRune) Error() string { + exp := bytes.Buffer{} + runes := err.expected + switch len(runes) { + default: + for _, r := range runes[:len(runes)-2] { + exp.WriteString(scanner.TokenString(r)) + exp.WriteString(", ") + } + fallthrough + case 2: + exp.WriteString(scanner.TokenString(runes[len(runes)-2])) + exp.WriteString(" or ") + fallthrough + case 1: + exp.WriteString(scanner.TokenString(runes[len(runes)-1])) + case 0: + return fmt.Sprintf("unexpected %s while scanning %s", scanner.TokenString(err.got), err.unit) + } + return fmt.Sprintf("unexpected %s while scanning %s expected %s", scanner.TokenString(err.got), err.unit, exp.String()) +} diff --git a/gval/parser_test.go b/gval/parser_test.go new file mode 100644 index 0000000..30b99bd --- /dev/null +++ b/gval/parser_test.go @@ -0,0 +1,148 @@ +package gval + +import ( + "testing" + "text/scanner" + "unicode" +) + +func TestParser_Scan(t *testing.T) { + tests := []struct { + name string + input string + Language + do func(p *Parser) + wantScan rune + wantToken string + wantPanic bool + }{ + { + name: "camouflage", + input: "$abc", + do: func(p *Parser) { + p.Scan() + p.Camouflage("test") + }, + wantScan: '$', + wantToken: "$", + }, + { + name: "camouflage with next", + input: "$abc", + do: func(p *Parser) { + p.Scan() + p.Camouflage("test") + p.Next() + }, + wantPanic: true, + }, + { + name: "camouflage scan camouflage", + input: "$abc", + do: func(p *Parser) { + p.Scan() + p.Camouflage("test") + p.Scan() + p.Camouflage("test2") + }, + wantScan: '$', + wantToken: "$", + }, + { + name: "camouflage with peek", + input: "$abc", + do: func(p *Parser) { + p.Scan() + p.Camouflage("test") + p.Peek() + }, + wantPanic: true, + }, + { + name: "next and peek", + input: "$#abc", + do: func(p *Parser) { + p.Scan() + p.Next() + p.Peek() + }, + wantScan: scanner.Ident, + wantToken: "abc", + }, + { + name: "scan token camouflage token", + input: "abc", + do: func(p *Parser) { + p.Scan() + p.TokenText() + p.Camouflage("test") + }, + wantScan: scanner.Ident, + wantToken: "abc", + }, + { + name: "scan token peek camouflage token", + input: "abc", + do: func(p *Parser) { + p.Scan() + p.TokenText() + p.Peek() + p.Camouflage("test") + }, + wantScan: scanner.Ident, + wantToken: "abc", + }, + { + name: "tokenize all whitespace", + input: "foo\tbar\nbaz", + do: func(p *Parser) { + p.SetWhitespace() + p.Scan() + }, + wantScan: '\t', + wantToken: "\t", + }, + { + name: "custom ident", + input: "$#foo", + do: func(p *Parser) { + p.SetIsIdentRuneFunc(func(ch rune, i int) bool { return unicode.IsLetter(ch) || ch == '#' }) + p.Scan() + }, + wantScan: scanner.Ident, + wantToken: "#foo", + }, + { + name: "do not scan idents", + input: "abc", + do: func(p *Parser) { + p.SetMode(scanner.GoTokens ^ scanner.ScanIdents) + p.Scan() + }, + wantScan: 'b', + wantToken: "b", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + err := recover() + if err != nil && !tt.wantPanic { + t.Fatalf("unexpected panic: %v", err) + } + }() + + p := newParser(tt.input, tt.Language) + tt.do(p) + if tt.wantPanic { + return + } + scan := p.Scan() + token := p.TokenText() + + if scan != tt.wantScan || token != tt.wantToken { + t.Errorf("Parser.Scan() = %v (%v), want %v (%v)", scan, token, tt.wantScan, tt.wantToken) + } + }) + } +} diff --git a/main.go b/main.go index 3b0e727..1ae364a 100644 --- a/main.go +++ b/main.go @@ -3,19 +3,18 @@ package main import ( _ "embed" "fmt" - "image/color" "okemu/config" - "okemu/debuger" + "okemu/debug" + "okemu/debug/listener" "okemu/logger" "okemu/okean240" + "okemu/okean240/fdc" + "okemu/z80/dis" "sync/atomic" "time" "fyne.io/fyne/v2" - "fyne.io/fyne/v2/app" "fyne.io/fyne/v2/canvas" - "fyne.io/fyne/v2/container" - "fyne.io/fyne/v2/driver/desktop" "fyne.io/fyne/v2/widget" ) @@ -49,107 +48,38 @@ func main() { // Reconfigure logging by config values // logger.ReconfigureLogging(conf) - computer := okean240.New(conf) + debugger := debug.NewDebugger() + computer := okean240.NewComputer(conf, debugger) + computer.SetSerialBytes(serialBytes) - computer.LoadFloppy() + + if conf.FDC.AutoLoadB { + err := computer.LoadFloppy(fdc.FloppyB) + if err != nil { + // show message + } + } + if conf.FDC.AutoLoadC { + err := computer.LoadFloppy(fdc.FloppyC) + if err != nil { + // show message + } + } + + disasm := dis.NewDisassembler(computer) w, raster, label := mainWindow(computer) go emulator(computer) go screen(computer, raster, label) - go debuger.SetupTcpHandler(conf, computer) + + if conf.Debugger.Enabled { + go listener.SetupTcpHandler(conf, debugger, disasm, computer) + } (*w).ShowAndRun() } -func mainWindow(computer *okean240.ComputerType) (*fyne.Window, *canvas.Raster, *widget.Label) { - emulatorApp := app.New() - w := emulatorApp.NewWindow("Океан 240.2") - w.Canvas().SetOnTypedKey( - func(key *fyne.KeyEvent) { - computer.PutKey(key) - }) - - w.Canvas().SetOnTypedRune( - func(key rune) { - computer.PutRune(key) - }) - - addShortcuts(w.Canvas(), computer) - - label := widget.NewLabel(fmt.Sprintf("Screen size: %dx%d", computer.ScreenWidth(), computer.ScreenHeight())) - - raster := canvas.NewRasterWithPixels( - func(x, y, w, h int) color.Color { - var xx uint16 - if computer.ScreenWidth() == 512 { - xx = uint16(x) - } else { - xx = uint16(x) / 2 - } - return computer.GetPixel(xx, uint16(y/2)) - }) - raster.Resize(fyne.NewSize(512, 512)) - raster.SetMinSize(fyne.NewSize(512, 512)) - - centerRaster := container.NewCenter(raster) - - w.Resize(fyne.NewSize(600, 600)) - - hBox := container.NewHBox( - //widget.NewButton("++", func() { - // computer.IncOffset() - //}), - //widget.NewButton("--", func() { - // computer.DecOffset() - //}), - widget.NewButton("Ctrl+C", func() { - computer.PutCtrlKey(0x03) - }), - widget.NewButton("Load Floppy", func() { - computer.LoadFloppy() - }), - widget.NewButton("Save Floppy", func() { - computer.SaveFloppy() - }), - widget.NewButton("RUN1", func() { - computer.SetRamBytes(ramBytes1) - }), - widget.NewButton("RUN2", func() { - computer.SetRamBytes(ramBytes2) - }), - widget.NewButton("DUMP", func() { - computer.Dump(0x100, 15000) - }), - widget.NewCheck("Full speed", func(b bool) { - fullSpeed.Store(b) - if b { - computer.SetCPUFrequency(50_000_000) - } else { - computer.SetCPUFrequency(2_500_000) - } - }), - widget.NewSeparator(), - widget.NewButton("Reset", func() { - needReset = true - //computer.Reset(conf) - }), - widget.NewSeparator(), - widget.NewButton("Закрыть", func() { - emulatorApp.Quit() - }), - ) - vBox := container.NewVBox( - centerRaster, - label, - hBox, - ) - - w.SetContent(vBox) - - return &w, raster, label -} - func screen(computer *okean240.ComputerType, raster *canvas.Raster, label *widget.Label) { ticker := time.NewTicker(20 * time.Millisecond) frame := 0 @@ -164,13 +94,17 @@ func screen(computer *okean240.ComputerType, raster *canvas.Raster, label *widge if frame%50 == 0 { freq = computer.Cycles() - pre pre = computer.Cycles() - label.SetText(fmt.Sprintf("Screen size: %dx%d F: %d", computer.ScreenWidth(), computer.ScreenHeight(), freq)) + label.SetText(formatLabel(computer, freq)) } raster.Refresh() }) } } +func formatLabel(computer *okean240.ComputerType, freq uint64) string { + return fmt.Sprintf("Screen size: %dx%d | F: %d | Debugger: %s", computer.ScreenWidth(), computer.ScreenHeight(), freq, computer.DebuggerState()) +} + const ticksPerTact uint64 = 4 func emulator(computer *okean240.ComputerType) { @@ -184,35 +118,24 @@ func emulator(computer *okean240.ComputerType) { // 1.5 MHz computer.TimerClk() } - if !computer.IsStepMode() || computer.IsRunMode() { - var bp uint16 = 0 - if fullSpeed.Load() { - _, bp = computer.Do() - } else { - if ticks >= nextClock { - var t uint32 - t, bp = computer.Do() - nextClock = ticks + uint64(t)*ticksPerTact - } - } - // Breakpoint hit - if bp > 0 { - debuger.BreakpointHit(bp) + var bp uint16 = 0 + var typ byte = 0 + if fullSpeed.Load() { + _, bp, typ = computer.Do() + } else { + if ticks >= nextClock { + var t uint32 + t, bp, typ = computer.Do() + nextClock = ticks + uint64(t)*ticksPerTact } } + // Breakpoint hit + if bp > 0 || typ != 0 { + listener.BreakpointHit(bp, typ) + } if needReset { computer.Reset() needReset = false } } } - -// Add shortcuts for all Ctrl+ -func addShortcuts(c fyne.Canvas, computer *okean240.ComputerType) { - // Add shortcuts for Ctrl+A to Ctrl+Z - for kName := 'A'; kName <= 'Z'; kName++ { - kk := fyne.KeyName(kName) - sc := &desktop.CustomShortcut{KeyName: kk, Modifier: fyne.KeyModifierControl} - c.AddShortcut(sc, func(shortcut fyne.Shortcut) { computer.PutCtrlKey(byte(kName&0xff) - 0x40) }) - } -} diff --git a/okean240/computer.go b/okean240/computer.go index 8c6c10a..47d4a13 100644 --- a/okean240/computer.go +++ b/okean240/computer.go @@ -3,8 +3,11 @@ package okean240 import ( _ "embed" "encoding/binary" + "encoding/json" + "fmt" "image/color" "okemu/config" + "okemu/debug" "okemu/okean240/fdc" "okemu/okean240/pic" "okemu/okean240/pit" @@ -22,39 +25,36 @@ import ( const DefaultCPUFrequency = 2_500_000 -type Breakpoint struct { - addr uint16 - enabled bool -} - type ComputerType struct { - cpu *c99.Z80 - memory Memory - ioPorts [256]byte - cycles uint64 - dd17EnableOut bool - colorMode bool - screenWidth int - screenHeight int - vRAM *RamBlock - palette byte - bgColor byte - pit *pit.I8253 - usart *usart.I8251 - pic *pic.I8259 - fdc *fdc.FloppyDriveController - kbdBuffer []byte - vShift byte - hShift byte - stepMode bool - runMode bool - bpEnabled bool - breakpoints [MaxBreakpoints]Breakpoint - //aOffset uint16 - cpuFrequency uint32 + cpu *c99.Z80 + memory Memory + ioPorts [256]byte + cycles uint64 + tstatesPartial uint64 + dd17EnableOut bool + colorMode bool + screenWidth int + screenHeight int + vRAM *RamBlock + palette byte + bgColor byte + pit *pit.I8253 + usart *usart.I8251 + pic *pic.I8259 + fdc *fdc.FloppyDriveController + kbdBuffer []byte + vShift byte + hShift byte + cpuFrequency uint32 + // + debugger *debug.Debugger +} + +type Snapshot struct { + CPU *z80.CPU `json:"cpu,omitempty"` + Memory string `json:"memory,omitempty"` } -const MaxBreakpoints = 256 const VRAMBlock0 = 3 const VRAMBlock1 = 7 const VidVsuBit = 0x80 @@ -71,66 +71,15 @@ type ComputerInterface interface { PutCtrlKey(shortcut fyne.Shortcut) SaveFloppy() LoadFloppy() - CPUState() *z80.Z80CPU - SetCPUState(state *z80.Z80CPU) - StepMode() bool - SetStepMode(step bool) - ClearMemBreakpoints() - SetBreakpointsEnabled(enabled bool) - IsBreakpoint() bool - //Dump(start uint16, length uint16) + CPUState() *z80.CPU + SetCPUState(state *z80.CPU) } -func (c *ComputerType) SetBreakpointsEnabled(enabled bool) { - c.bpEnabled = enabled -} - -func (c *ComputerType) IsBreakpointsEnabled() bool { - return c.bpEnabled -} - -func (c *ComputerType) SetBreakpoint(no uint16, addr uint16) { - if no > 0 && no <= MaxBreakpoints { - c.breakpoints[no-1].addr = addr - log.Debugf("BP[%d] SET AT PC=%04X", no, addr) - } else { - log.Warnf("Breakpoint number %d out or range!", no) - } -} - -func (c *ComputerType) SetBreakpointEnabled(no uint16, enabled bool) { - if no <= MaxBreakpoints && no > 0 { - c.breakpoints[no-1].enabled = enabled - } else { - log.Warnf("Breakpoint number %d out or range!", no) - } -} - -func (c *ComputerType) IsBreakpointEnabled(no uint16) bool { - if no <= MaxBreakpoints && no > 0 { - return c.breakpoints[no-1].enabled - } - log.Warnf("Breakpoint number %d out or range!", no) - return false -} - -func (c *ComputerType) ClearMemBreakpoints() { - log.Warnf("Clearing memory bpEnabled unimplemented!") -} - -func (c *ComputerType) SetStepMode(step bool) { - c.stepMode = step -} - -func (c *ComputerType) IsStepMode() bool { - return c.stepMode -} - -func (c *ComputerType) GetCPUState() *z80.Z80CPU { +func (c *ComputerType) GetCPUState() *z80.CPU { return c.cpu.GetState() } -func (c *ComputerType) SetCPUState(state *z80.Z80CPU) { +func (c *ComputerType) SetCPUState(state *z80.CPU) { c.cpu.SetState(state) } @@ -146,8 +95,8 @@ func (c *ComputerType) MemWrite(addr uint16, val byte) { c.memory.MemWrite(addr, val) } -// New Builds new computer -func New(cfg *config.OkEmuConfig) *ComputerType { +// NewComputer Builds new computer +func NewComputer(cfg *config.OkEmuConfig, deb *debug.Debugger) *ComputerType { c := ComputerType{} c.memory = Memory{} c.memory.Init(cfg.MonitorFile, cfg.CPMFile) @@ -155,6 +104,7 @@ func New(cfg *config.OkEmuConfig) *ComputerType { c.cpu = c99.New(&c) c.cycles = 0 + c.tstatesPartial = 0 c.dd17EnableOut = false c.screenWidth = 512 c.screenHeight = 256 @@ -169,58 +119,80 @@ func New(cfg *config.OkEmuConfig) *ComputerType { c.pit = pit.New() c.usart = usart.New() c.pic = pic.New() - c.fdc = fdc.New(cfg) + c.fdc = fdc.NewFDC(cfg) c.cpuFrequency = DefaultCPUFrequency - c.bpEnabled = false - c.breakpoints = [256]Breakpoint{} - for i := range c.breakpoints { - c.breakpoints[i] = Breakpoint{} - c.breakpoints[i].enabled = false - c.breakpoints[i].addr = 0 - } + c.debugger = deb return &c } func (c *ComputerType) Reset() { c.cpu.Reset() c.cycles = 0 - //c.vShift = 0 - //c.hShift = 0 - //c.memory = Memory{} - //c.memory.Init(cfg.MonitorFile, cfg.CPMFile) - //c.dd17EnableOut = false - //c.screenWidth = 256 - //c.screenHeight = 256 - //c.vRAM = c.memory.allMemory[3] - + c.tstatesPartial = 0 } -func (c *ComputerType) SetRunMode(run bool) { - c.runMode = run +func (c *ComputerType) getContext() map[string]interface{} { + context := make(map[string]interface{}) + s := c.cpu.GetState() + context["A"] = s.A + context["B"] = s.B + context["C"] = s.C + context["D"] = s.D + context["E"] = s.E + context["H"] = s.H + context["L"] = s.L + context["A'"] = s.AAlt + context["B'"] = s.BAlt + context["C'"] = s.CAlt + context["D'"] = s.DAlt + context["E'"] = s.EAlt + context["H'"] = s.HAlt + context["L'"] = s.LAlt + context["PC"] = s.PC + context["SP"] = s.SP + context["IX"] = s.IX + context["IY"] = s.IY + context["ZF"] = s.Flags.Z + context["SF"] = s.Flags.S + context["NF"] = s.Flags.N + context["PF"] = s.Flags.P + context["HF"] = s.Flags.H + context["YF"] = s.Flags.Y + context["XF"] = s.Flags.X + context["CF"] = s.Flags.C + context["BC"] = uint16(s.B)<<8 | uint16(s.C) + context["DE"] = uint16(s.D)<<8 | uint16(s.E) + context["HL"] = uint16(s.H)<<8 | uint16(s.L) + context["AF"] = uint16(s.A)<<8 | uint16(s.Flags.GetFlags()) + return context } -func (c *ComputerType) IsRunMode() bool { - return c.runMode -} - -func (c *ComputerType) Do() (uint32, uint16) { - // check breakpoints - if c.bpEnabled && c.runMode { - for no, bp := range c.breakpoints { - if bp.enabled && bp.addr == c.cpu.GetState().PC { - c.runMode = false - return 0, uint16(no + 1) +func (c *ComputerType) Do() (uint32, uint16, byte) { + ticks := uint32(0) + var memAccess *map[uint16]byte + if c.debugger.StepMode() { + if c.debugger.RunMode() || c.debugger.DoStep() { + if c.debugger.RunInst() > 0 { + // skip first instruction after run-mode activated + bpHit, bp := c.debugger.CheckBreakpoints(c.getContext()) + if bpHit { + //c.debugger.SetRunMode(false) + return 0, bp, 0 + } + } + c.debugger.SaveHistory(c.cpu.GetState()) + ticks, memAccess = c.cpu.RunInstruction() + mHit, mAddr, mTyp := c.debugger.CheckMemBreakpoints(memAccess) + if mHit { + return ticks, mAddr, mTyp } } + } else { + ticks, memAccess = c.cpu.RunInstruction() } - - ticks := c.cpu.RunInstruction() c.cycles += uint64(ticks) - //pc := c.cpu.GetState().PC - //if pc >= 0xfea3 && pc <= 0xff25 { - // c.cpu.DebugOutput() - //} - return ticks, 0 + c.tstatesPartial += uint64(ticks) + return ticks, 0, 0 } func (c *ComputerType) GetPixel(x uint16, y uint16) color.RGBA { @@ -303,6 +275,14 @@ func (c *ComputerType) Cycles() uint64 { return c.cycles } +func (c *ComputerType) ResetTStatesPartial() { + c.tstatesPartial = 0 +} + +func (c *ComputerType) TStatesPartial() uint64 { + return c.tstatesPartial +} + func (c *ComputerType) TimerClk() { // DD70 KR580VI53 CLK0, CKL1 @ 1.5MHz c.pit.Tick(0) @@ -319,12 +299,12 @@ func (c *ComputerType) TimerClk() { } } -func (c *ComputerType) LoadFloppy() { - c.fdc.LoadFloppy() +func (c *ComputerType) LoadFloppy(drive byte) error { + return c.fdc.LoadFloppy(drive) } -func (c *ComputerType) SaveFloppy() { - c.fdc.SaveFloppy() +func (c *ComputerType) SaveFloppy(drive byte) error { + return c.fdc.SaveFloppy(drive) } func (c *ComputerType) SetSerialBytes(bytes []byte) { @@ -379,3 +359,51 @@ func (c *ComputerType) CPUFrequency() uint32 { func (c *ComputerType) SetCPUFrequency(frequency uint32) { c.cpuFrequency = frequency } + +func (c *ComputerType) DebuggerState() string { + if c.debugger.StepMode() { + if c.debugger.RunMode() { + return "Run" + } + return "Step" + } + return "Off" +} + +func (c *ComputerType) memoryAsHexStr() string { + res := "" + for addr := 0; addr <= 65535; addr++ { + res += fmt.Sprintf("%02X", c.memory.MemRead(uint16(addr))) + } + return res +} + +func (c *ComputerType) SaveSnapshot(fn string) error { + // create snapshot file + file, err := os.Create(fn) + if err != nil { + return err + } + defer func() { + err := file.Close() + if err != nil { + log.Error(err) + } + }() + // take snapshot + s := Snapshot{ + CPU: c.cpu.GetState(), + Memory: c.memoryAsHexStr(), + } + // convert to JSON + b, err := json.Marshal(s) + if err != nil { + return err + } + // and save + err = binary.Write(file, binary.LittleEndian, b) + if err != nil { + return err + } + return nil +} diff --git a/okean240/fdc/fdc.go b/okean240/fdc/fdc.go index 3f44911..18e3325 100644 --- a/okean240/fdc/fdc.go +++ b/okean240/fdc/fdc.go @@ -10,6 +10,7 @@ package fdc import ( "bytes" "encoding/binary" + "errors" "okemu/config" "os" "slices" @@ -21,6 +22,9 @@ import ( // Floppy parameters const ( + FloppyB = 0 + FloppyC = 1 + TotalDrives = 2 FloppySizeK = 720 SectorSize = 512 @@ -84,8 +88,8 @@ type FloppyDriveController struct { //curSector *SectorType bytePtr uint16 trackBuffer []byte - floppyBFile string - floppyCFile string + floppyFile []string + config *config.OkEmuConfig } type FloppyDriveControllerInterface interface { @@ -98,7 +102,7 @@ type FloppyDriveControllerInterface interface { SetData(value byte) Data() byte Drq() byte - SaveFloppy() + SaveFloppy(drive byte) GetSectorNo() uint16 Track() byte Sector() byte @@ -326,17 +330,21 @@ func (f *FloppyDriveController) Drq() byte { return f.drq } -func (f *FloppyDriveController) LoadFloppy() { - loadFloppy(&f.sectors[0], f.floppyBFile) - loadFloppy(&f.sectors[1], f.floppyCFile) +func (f *FloppyDriveController) LoadFloppy(drive byte) error { + if drive < TotalDrives { + return loadFloppy(&f.sectors[drive], f.floppyFile[drive]) + } + return errors.New("DriveNo " + strconv.Itoa(int(drive)) + " out of range") } -func (f *FloppyDriveController) SaveFloppy() { - saveFloppy(&f.sectors[0], f.floppyBFile) - saveFloppy(&f.sectors[1], f.floppyCFile) +func (f *FloppyDriveController) SaveFloppy(drive byte) error { + if drive < TotalDrives { + return saveFloppy(&f.sectors[drive], f.floppyFile[drive]) + } + return errors.New("DriveNo " + strconv.Itoa(int(drive)) + " out of range") } -func New(conf *config.OkEmuConfig) *FloppyDriveController { +func NewFDC(conf *config.OkEmuConfig) *FloppyDriveController { sec := [2][SizeInSectors]SectorType{} // for each drive for d := 0; d < TotalDrives; d++ { @@ -346,20 +354,19 @@ func New(conf *config.OkEmuConfig) *FloppyDriveController { } } return &FloppyDriveController{ - sideNo: 0, - ddEn: 0, - init: 0, - drive: 0, - mot1: 0, - mot0: 0, - intRq: 0, - motSt: 0, - drq: 0, - lastCmd: 0xff, - sectors: sec, - bytePtr: 0xffff, - floppyBFile: conf.FloppyB, - floppyCFile: conf.FloppyC, + sideNo: 0, + ddEn: 0, + init: 0, + drive: 0, + mot1: 0, + mot0: 0, + intRq: 0, + motSt: 0, + drq: 0, + lastCmd: 0xff, + sectors: sec, + bytePtr: 0xffff, + floppyFile: []string{conf.FDC.FloppyB, conf.FDC.FloppyC}, } } @@ -393,12 +400,12 @@ func (f *FloppyDriveController) Sector() byte { } // loadFloppy load floppy image to sector buffer from file -func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) { +func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) error { log.Debugf("Load Floppy content from file %s.", fileName) file, err := os.Open(fileName) if err != nil { log.Error(err) - return + return err } defer func(file *os.File) { @@ -416,17 +423,19 @@ func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) { } if err != nil { log.Error("Load floppy content failed:", err) - break + return err } } + return nil } -func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) { +// saveFloppy Save specified sectors to file with name fileName +func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) error { log.Debugf("Save Floppy to file %s.", fileName) file, err := os.Create(fileName) if err != nil { log.Error(err) - return + return err } defer func(file *os.File) { err := file.Close() @@ -443,7 +452,8 @@ func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) { } if err != nil { log.Error("Save floppy content failed:", err) - break + return err } } + return nil } diff --git a/okemu.yml b/okemu.yml index a48a875..5487ce8 100644 --- a/okemu.yml +++ b/okemu.yml @@ -1,8 +1,15 @@ logFile: "okemu.log" logLevel: "info" + monitorFile: "rom/MON_r8_9c6c6546.bin" -# cpmFile: "rom/CPM_r7_b89a7e16.bin" cpmFile: "rom/CPM_r8_bc0695e4.bin" -floppyB: "floppy/floppyB.okd" -floppyC: "floppy/floppyC.okd" -port: 10001 + +fdc: + autoLoadB: true + floppyB: "floppy/floppyB.okd" + autoLoadC: true + floppyC: "floppy/floppyC.okd" + +debugger: + enabled: true + port: 10001 diff --git a/z80/c99/cpu.go b/z80/c99/cpu.go index 3a43a84..af88099 100644 --- a/z80/c99/cpu.go +++ b/z80/c99/cpu.go @@ -33,9 +33,15 @@ type Z80 struct { intPending bool nmiPending bool - core z80.MemIoRW + core z80.MemIoRW + memAccess map[uint16]byte } +const ( + MemAccessRead = 1 + MemAccessWrite = 2 +) + // New initializes a Z80 instance and return pointer to it func New(core z80.MemIoRW) *Z80 { z := Z80{} @@ -90,7 +96,9 @@ func New(core z80.MemIoRW) *Z80 { } // RunInstruction executes the next instruction in memory + handles interrupts -func (z *Z80) RunInstruction() uint32 { +func (z *Z80) RunInstruction() (uint32, *map[uint16]byte) { + z.memAccess = map[uint16]byte{} + pre := z.cycleCount if z.isHalted { z.execOpcode(0x00) @@ -99,10 +107,10 @@ func (z *Z80) RunInstruction() uint32 { z.execOpcode(opcode) } z.processInterrupts() - return z.cycleCount - pre + return z.cycleCount - pre, &z.memAccess } -func (z *Z80) SetState(state *z80.Z80CPU) { +func (z *Z80) SetState(state *z80.CPU) { z.cycleCount = 0 z.a = state.A z.b = state.B @@ -148,8 +156,8 @@ func (z *Z80) SetState(state *z80.Z80CPU) { z.nmiPending = false z.intData = 0 } -func (z *Z80) GetState() *z80.Z80CPU { - return &z80.Z80CPU{ +func (z *Z80) GetState() *z80.CPU { + return &z80.CPU{ A: z.a, B: z.b, C: z.c, @@ -211,3 +219,7 @@ func (z *Z80) getAltFlags() z80.FlagsType { C: z.f_&0x01 != 0, } } + +func (z *Z80) PC() uint16 { + return z.pc +} diff --git a/z80/c99/helper.go b/z80/c99/helper.go index 47bab45..6569fe9 100644 --- a/z80/c99/helper.go +++ b/z80/c99/helper.go @@ -3,18 +3,24 @@ package c99 import log "github.com/sirupsen/logrus" func (z *Z80) rb(addr uint16) byte { + z.memAccess[addr] = MemAccessRead return z.core.MemRead(addr) } func (z *Z80) wb(addr uint16, val byte) { + z.memAccess[addr] = MemAccessWrite z.core.MemWrite(addr, val) } func (z *Z80) rw(addr uint16) uint16 { + z.memAccess[addr] = MemAccessRead + z.memAccess[addr+1] = MemAccessRead return (uint16(z.core.MemRead(addr+1)) << 8) | uint16(z.core.MemRead(addr)) } func (z *Z80) ww(addr uint16, val uint16) { + z.memAccess[addr] = MemAccessWrite + z.memAccess[addr+1] = MemAccessWrite z.core.MemWrite(addr, byte(val)) z.core.MemWrite(addr+1, byte(val>>8)) } @@ -30,13 +36,13 @@ func (z *Z80) popW() uint16 { } func (z *Z80) nextB() byte { - b := z.rb(z.pc) + b := z.core.MemRead(z.pc) z.pc++ return b } func (z *Z80) nextW() uint16 { - w := z.rw(z.pc) + w := (uint16(z.core.MemRead(z.pc+1)) << 8) | uint16(z.core.MemRead(z.pc)) z.pc += 2 return w } diff --git a/z80/cpu.go b/z80/cpu.go index b70b96f..d46772b 100644 --- a/z80/cpu.go +++ b/z80/cpu.go @@ -19,58 +19,58 @@ type CPUInterface interface { // RunInstruction Run single instruction, return number of CPU cycles RunInstruction() uint32 // GetState Get current CPU state - GetState() *Z80CPU + GetState() *CPU // SetState Set current CPU state - SetState(state *Z80CPU) + SetState(state *CPU) // DebugOutput out current CPU state DebugOutput() } // FlagsType - Processor flags type FlagsType struct { - S bool - Z bool - Y bool - H bool - X bool - P bool - N bool - C bool + S bool `json:"s,omitempty"` + Z bool `json:"z,omitempty"` + Y bool `json:"y,omitempty"` + H bool `json:"h,omitempty"` + X bool `json:"x,omitempty"` + P bool `json:"p,omitempty"` + N bool `json:"n,omitempty"` + C bool `json:"c,omitempty"` } // Z80CPU - Processor state -type Z80CPU struct { - A byte - B byte - C byte - D byte - E byte - H byte - L byte - AAlt byte - BAlt byte - CAlt byte - DAlt byte - EAlt byte - HAlt byte - LAlt byte - IX uint16 - IY uint16 - I byte - R byte - SP uint16 - PC uint16 - Flags FlagsType - FlagsAlt FlagsType - IMode byte - Iff1 bool - Iff2 bool - Halted bool - DoDelayedDI bool - DoDelayedEI bool - CycleCount uint32 - InterruptOccurred bool - MemPtr uint16 +type CPU struct { + A byte `json:"a,omitempty"` + B byte `json:"b,omitempty"` + C byte `json:"c,omitempty"` + D byte `json:"d,omitempty"` + E byte `json:"e,omitempty"` + H byte `json:"h,omitempty"` + L byte `json:"l,omitempty"` + AAlt byte `json:"AAlt,omitempty"` + BAlt byte `json:"BAlt,omitempty"` + CAlt byte `json:"CAlt,omitempty"` + DAlt byte `json:"DAlt,omitempty"` + EAlt byte `json:"EAlt,omitempty"` + HAlt byte `json:"HAlt,omitempty"` + LAlt byte `json:"LAlt,omitempty"` + IX uint16 `json:"IX,omitempty"` + IY uint16 `json:"IY,omitempty"` + I byte `json:"i,omitempty"` + R byte `json:"r,omitempty"` + SP uint16 `json:"SP,omitempty"` + PC uint16 `json:"PC,omitempty"` + Flags FlagsType `json:"flags"` + FlagsAlt FlagsType `json:"flagsAlt"` + IMode byte `json:"IMode,omitempty"` + Iff1 bool `json:"iff1,omitempty"` + Iff2 bool `json:"iff2,omitempty"` + Halted bool `json:"halted,omitempty"` + DoDelayedDI bool `json:"doDelayedDI,omitempty"` + DoDelayedEI bool `json:"doDelayedEI,omitempty"` + CycleCount uint32 `json:"cycleCount,omitempty"` + InterruptOccurred bool `json:"interruptOccurred,omitempty"` + MemPtr uint16 `json:"memPtr,omitempty"` //core MemIoRW } @@ -132,7 +132,7 @@ func (f *FlagsType) GetFlagsStr() string { return string(flags) } -func (z *Z80CPU) IIFStr() string { +func (z *CPU) IIFStr() string { flags := []byte{'-', '-'} if z.Iff1 { flags[0] = '1' @@ -166,3 +166,7 @@ func (f *FlagsType) SetFlags(flags byte) { f.N = flags&0x02 != 0 f.C = flags&0x01 != 0 } + +func (z *CPU) GetPC() uint16 { + return z.PC +} diff --git a/z80/cpu_test.go b/z80/cpu_test.go index 46f7ae6..432efc2 100644 --- a/z80/cpu_test.go +++ b/z80/cpu_test.go @@ -413,7 +413,7 @@ func TestZ80Fuse(t *testing.T) { } func setComputerState(test Z80TestIn) { - state := z80.Z80CPU{ + state := z80.CPU{ A: byte(test.registers.AF >> 8), B: byte(test.registers.BC >> 8), C: byte(test.registers.BC), diff --git a/z80/dis/z80disasm.go b/z80/dis/z80disasm.go new file mode 100644 index 0000000..45f9653 --- /dev/null +++ b/z80/dis/z80disasm.go @@ -0,0 +1,672 @@ +package dis + +import ( + "fmt" + "okemu/z80" + "strings" +) + +type Disassembler struct { + pc uint16 + core z80.MemIoRW +} + +type Disassembly interface { + Disassm(pc uint16) string +} + +func NewDisassembler(core z80.MemIoRW) *Disassembler { + d := Disassembler{ + pc: 0, + core: core, + } + return &d +} + +// opcode & 0x07 +var operands = []string{"B", "C", "D", "E", "H", "L", "(HL)", "A"} +var aluOp = []string{"ADD A" + sep, "ADC A" + sep, "SUB ", "SBC A" + sep, "AND ", "XOR ", "OR ", "CP "} + +const sep = ", " + +func (d *Disassembler) jp(op, cond string) string { + addr := d.getW() + if cond != "" { + cond += sep + } + return fmt.Sprintf("%s %s%s", op, cond, addr) +} + +func (d *Disassembler) jr(op, cond string) string { + addr := d.pc + offset := d.getByte() + if offset&0x80 != 0 { + addr += 0xFF00 | uint16(offset) + } else { + addr += d.pc + uint16(offset) + } + if cond != "" { + cond += sep + } + return fmt.Sprintf("%s %s0x%04X", op, cond, addr) +} + +func (d *Disassembler) getByte() byte { + b := d.core.MemRead(d.pc) + d.pc++ + return b +} + +func (d *Disassembler) Disassm(pc uint16) string { + d.pc = pc + result := fmt.Sprintf(" %04X ", d.pc) + op := d.getByte() + + switch { + // == 00:0F + case op == 0x00: + result += "NOP" + case op == 0x01: + result += "LD BC" + sep + d.getW() + case op == 0x02: + result += "LD (BC)" + sep + "A" + case op == 0x03: + result += "INC BC" + case op == 0x04: + result += "INC B" + case op == 0x05: + result += "DEC B" + case op == 0x06: + result += "LD B" + sep + d.getB() + case op == 0x07: + result += "RLCA" + case op == 0x08: + result += "EX AF, AF'" + case op == 0x09: + result += "ADD HL" + sep + "BC" + case op == 0x0A: + result += "LD A" + sep + "(BC)" + case op == 0x0B: + result += "DEC BC" + case op == 0x0C: + result += "INC C" + case op == 0x0D: + result += "DEC C" + case op == 0x0E: + result += "LD C" + sep + d.getB() + case op == 0x0F: + result += "RRCA" + // 10:1F + case op == 0x10: + // DJNZ rel + result += d.jr("DJNZ", "") + case op == 0x11: + result += "LD DE" + sep + d.getW() + case op == 0x12: + result += "LD (DE)" + sep + "A" + case op == 0x13: + result += "INC DE" + case op == 0x14: + result += "INC D" + case op == 0x15: + result += "DEC D" + case op == 0x16: + result += "LD D" + sep + d.getB() + case op == 0x17: + result += "RLA" + case op == 0x18: + result += d.jr("JR", "") + case op == 0x19: + result += "ADD HL" + sep + "DE" + case op == 0x1A: + result += "LD A" + sep + "(DE)" + case op == 0x1B: + result += "DEC DE" + case op == 0x1C: + result += "INC E" + case op == 0x1D: + result += "DEC E" + case op == 0x1E: + result += "LD E" + sep + d.getB() + case op == 0x1F: + result += "RRA" + // == 20:2F + case op == 0x20: + result += d.jr("JR", "NZ") + case op == 0x21: + result += "LD HL" + sep + d.getW() + case op == 0x22: + // LD (nn),HL + result += "LD (" + d.getW() + ")" + sep + "HL" + case op == 0x23: + result += "INC HL" + case op == 0x24: + result += "INC H" + case op == 0x25: + result += "DEC H" + case op == 0x26: + result += "LD H" + sep + d.getB() + case op == 0x27: + result += "DAA" + case op == 0x28: + result += d.jr("JR", "Z") + case op == 0x29: + result += "ADD HL" + sep + "HL" + case op == 0x2A: + result += "LD HL" + sep + "(" + d.getW() + ")" + case op == 0x2B: + result += "DEC HL" + case op == 0x2C: + result += "INC L" + case op == 0x2D: + result += "DEC L" + case op == 0x2E: + result += "LD L" + sep + d.getB() + case op == 0x2F: + result += "CPL" + // == 30:3F + case op == 0x30: + result += d.jr("JR", "NC") + case op == 0x31: + result += "LD SP" + sep + d.getW() + case op == 0x32: + result += "LD (" + d.getW() + ")" + sep + "A" + case op == 0x33: + result += "INC SP" + case op == 0x34: + result += "INC (HL)" + case op == 0x35: + result += "DEC (HL)" + case op == 0x36: + result += "LD (HL)" + sep + d.getB() + case op == 0x37: + result += "SCF" + case op == 0x38: + result += d.jr("JR", "C") + case op == 0x39: + result += "ADD HL" + sep + "SP" + case op == 0x3A: + result += "LD A" + sep + "(" + d.getW() + ")" + case op == 0x3B: + result += "DEC SP" + case op == 0x3C: + result += "INC A" + case op == 0x3D: + result += "DEC A" + case op == 0x3E: + result += "LD A" + sep + d.getB() + case op == 0x3F: + result += "CCF" + case op == 0x76: + result += "HALT" + case op >= 0x40 && op <= 0x7F: + // LD op8, op8 + result += "LD " + operands[(op>>3)&0x07] + sep + operands[op&0x07] + case op >= 0x80 && op <= 0xBF: + // ALU op8 + result += aluOp[(op>>3)&0x07] + operands[op&0x07] + case op == 0xc0: + result += "RET NZ" + case op == 0xc1: + result += "POP BC" + case op == 0xc2: + result += d.jp("JP", "NZ") + case op == 0xc3: + result += d.jp("JP", "") + case op == 0xc4: + result += d.jp("CALL", "NZ") + case op == 0xc5: + result += "PUSH BC" + case op == 0xc6: + result += "ADD A" + sep + d.getB() + case op == 0xc7 || op == 0xd7 || op == 0xe7 || op == 0xf7 || op == 0xcf || op == 0xdf || op == 0xef || op == 0xff: + // RST nnH + result += fmt.Sprintf("RST %d%dH", (op>>4)&3, (op&1)*8) + case op == 0xc8: + result += "RET Z" + case op == 0xc9: + result += "RET" + case op == 0xca: + result += d.jp("JP", "Z") + case op == 0xcb: + result += d.opocodeCB() + case op == 0xcc: + result += d.jp("CALL", "Z") + case op == 0xcd: + result += d.jp("CALL", "") + case op == 0xce: + result += "ADC A" + sep + d.getB() + case op == 0xd0: + result += "RET NC" + case op == 0xd1: + result += "POP DE" + case op == 0xd2: + result += d.jp("JP", "NC") + case op == 0xd3: + result += "OUT (" + d.getB() + ")" + sep + "A" + case op == 0xd4: + result += d.jp("CALL", "NC") + case op == 0xd5: + result += "PUSH DE" + case op == 0xd6: + result += "SUB " + d.getB() + case op == 0xd8: + result += "RET C" + case op == 0xd9: + result += "EXX" + case op == 0xda: + result += d.jp("JP", "C") + case op == 0xdb: + result += "IN A" + sep + " (" + d.getB() + ")" + case op == 0xdc: + result += d.jp("CALL", "C") + case op == 0xdd: + result += d.opocodeDD(op) + case op == 0xde: + result += "SBC A" + sep + d.getB() + case op == 0xe0: + result += "RET PO" + case op == 0xe1: + result += "POP HL" + case op == 0xe2: + result += d.jp("JP", "PO") + case op == 0xe3: + result += "EX (SP)" + sep + "HL" + case op == 0xe4: + result += d.jp("CALL", "PO") + case op == 0xe5: + result += "PUSH HL" + case op == 0xe6: + result += "AND " + d.getB() + case op == 0xe8: + result += "RET PE" + case op == 0xe9: + result += "JP (HL)" + case op == 0xea: + result += d.jp("JP", "PE") + case op == 0xeb: + result += "EX DE" + sep + "HL" + case op == 0xec: + result += d.jp("CALL", "PE") + case op == 0xed: + result += d.opocodeED() + case op == 0xee: + result += "XOR " + d.getB() + case op == 0xf0: + result += "RET P" + case op == 0xf1: + result += "POP AF" + case op == 0xf2: + result += d.jp("JP", "P") + case op == 0xf3: + result += "DI" + case op == 0xf4: + result += d.jp("CALL", "P") + case op == 0xf5: + result += "PUSH AF" + case op == 0xf6: + result += "OR " + d.getB() + case op == 0xf8: + result += "RET M" + case op == 0xf9: + result += "LD SP" + sep + "HL" + case op == 0xfa: + result += d.jp("JP", "M") + case op == 0xfb: + result += "EI" + case op == 0xfc: + result += d.jp("CALL", "M") + case op == 0xfd: + result += d.opocodeDD(op) + case op == 0xfe: + result += "CP " + d.getB() + default: + // All unknown as DB + result += fmt.Sprintf("DB 0x%02X", op) + } + return result +} + +func (d *Disassembler) getW() string { + lo := d.core.MemRead(d.pc) + d.pc++ + hi := d.core.MemRead(d.pc) + d.pc++ + return fmt.Sprintf("0x%02X%02X", hi, lo) +} + +func (d *Disassembler) getB() string { + lo := d.core.MemRead(d.pc) + d.pc++ + return fmt.Sprintf("0x%02X", lo) +} + +func (d *Disassembler) getRel() string { + offset := d.core.MemRead(d.pc) + var sign string + if int8(offset) < 0 { + sign = "-" + } else { + sign = "+" + } + return sign + fmt.Sprintf("0x%02X", offset&0x7F) +} + +var shiftOps = []string{"RLC", "RRC", "RL", "RR", "SLA", "SRA", "SLL", "SRL"} +var bitOps = []string{"BIT", "RES", "SET"} + +// opocodeCB disassemble Z80 Opcodes, with CB first byte +func (d *Disassembler) opocodeCB() string { + op := "" + opcode := d.getByte() + if opcode <= 0x3F { + op = shiftOps[opcode>>3&0x07] + operands[opcode&0x7] + } else { + op = shiftOps[(opcode>>6&0x03)-1] + operands[opcode&0x7] + } + return op +} + +func (d *Disassembler) opocodeDD(op byte) string { + opcode := d.getByte() + result := "" + switch opcode { + case 0x09: + result = "ADD ii" + sep + "BC" + case 0x19: + result = "ADD ii" + sep + "DE" + case 0x21: + result = "LD ii" + sep + d.getW() + case 0x22: + result = "LD (" + d.getW() + ")" + sep + "ii" + case 0x23: + result = "INC ii" + case 0x24: + result = "INC IXH" + case 0x25: + result = "DEC IXH" + case 0x26: + result = "LD IXH" + sep + "n" + case 0x29: + result = "ADD ii" + sep + "ii" + case 0x2A: + result = "LD ii" + sep + "(" + d.getW() + ")" + case 0x2B: + result = "DEC ii" + case 0x34: + result = "INC (ii" + d.getRel() + ")" + case 0x35: + result = "DEC (ii" + d.getRel() + ")" + case 0x36: + result = "LD (ii" + d.getRel() + ")" + sep + "n" + case 0x39: + result = "ADD ii" + sep + "SP" + case 0x46: + result = "LD B" + sep + "(ii" + d.getRel() + ")" + case 0x4E: + result = "LD C" + sep + "(ii" + d.getRel() + ")" + case 0x56: + result = "LD D" + sep + "(ii" + d.getRel() + ")" + case 0x5E: + result = "LD E" + sep + "(ii" + d.getRel() + ")" + case 0x66: + result = "LD H" + sep + "(ii" + d.getRel() + ")" + case 0x6E: + result = "LD L" + sep + "(ii" + d.getRel() + ")" + case 0x70: + result = "LD (ii" + d.getRel() + ")" + sep + "B" + case 0x71: + result = "LD (ii" + d.getRel() + ")" + sep + "C" + case 0x72: + result = "LD (ii" + d.getRel() + ")" + sep + "D" + case 0x73: + result = "LD (ii" + d.getRel() + ")" + sep + "E" + case 0x74: + result = "LD (ii" + d.getRel() + ")" + sep + "H" + case 0x75: + result = "LD (ii" + d.getRel() + ")" + sep + "L" + case 0x77: + result = "LD (ii" + d.getRel() + ")" + sep + "A" + case 0x7E: + result = "LD A" + sep + "(ii" + d.getRel() + ")" + case 0x86: + result = "ADD A" + sep + "(ii" + d.getRel() + ")" + case 0x8E: + result = "ADC A" + sep + "(ii" + d.getRel() + ")" + case 0x96: + result = "SUB (ii" + d.getRel() + ")" + case 0x9E: + result = "SBC A" + sep + "(ii" + d.getRel() + ")" + case 0xA6: + result = "AND (ii" + d.getRel() + ")" + case 0xAE: + result = "XOR (ii" + d.getRel() + ")" + case 0xB6: + result = "OR (ii" + d.getRel() + ")" + case 0xBE: + result = "CP (ii" + d.getRel() + ")" + case 0xCB: + result = d.opocodeDDCB(op, opcode) + case 0xE1: + result = "POP ii" + case 0xE3: + result = "EX (SP)" + sep + "ii" + case 0xE5: + result = "PUSH ii" + case 0xE9: + result = "JP (ii)" + case 0xF9: + result = "LD SP" + sep + "ii" + default: + return fmt.Sprintf("DB 0x%02X, 0x%02X", op, opcode) + } + + reg := "IX" + if op == 0xFD { + reg = "IY" + } + + return strings.ReplaceAll(result, "ii", reg) +} + +func (d *Disassembler) opocodeDDCB(op1 byte, op2 byte) string { + opcode := d.getByte() + result := "" + switch opcode { + case 0x06: + result = "RLC (ii" + d.getRel() + ")" + case 0x0E: + result = "RRC (ii" + d.getRel() + ")" + case 0x16: + result = "RL (ii" + d.getRel() + ")" + case 0x1E: + result = "RR (ii" + d.getRel() + ")" + case 0x26: + result = "SLA (ii" + d.getRel() + ")" + case 0x2E: + result = "SRA (ii" + d.getRel() + ")" + case 0x3E: + result = "SRL (ii" + d.getRel() + ")" + case 0x46: + result = "BIT 0" + sep + "(ii" + d.getRel() + ")" + case 0x4E: + result = "BIT 1" + sep + "(ii" + d.getRel() + ")" + case 0x56: + result = "BIT 2" + sep + "(ii" + d.getRel() + ")" + case 0x5E: + result = "BIT 3" + sep + "(ii" + d.getRel() + ")" + case 0x66: + result = "BIT 4" + sep + "(ii" + d.getRel() + ")" + case 0x6E: + result = "BIT 5" + sep + "(ii" + d.getRel() + ")" + case 0x76: + result = "BIT 6" + sep + "(ii" + d.getRel() + ")" + case 0x7E: + result = "BIT 7" + sep + "(ii" + d.getRel() + ")" + case 0x86: + result = "RES 0" + sep + "(ii" + d.getRel() + ")" + case 0x8E: + result = "RES 1" + sep + "(ii" + d.getRel() + ")" + case 0x96: + result = "RES 2" + sep + "(ii" + d.getRel() + ")" + case 0x9E: + result = "RES 3" + sep + "(ii" + d.getRel() + ")" + case 0xA6: + result = "RES 4" + sep + "(ii" + d.getRel() + ")" + case 0xAE: + result = "RES 5" + sep + "(ii" + d.getRel() + ")" + case 0xB6: + result = "RES 6" + sep + "(ii" + d.getRel() + ")" + case 0xBE: + result = "RES 7" + sep + "(ii" + d.getRel() + ")" + case 0xC6: + result = "SET 0" + sep + "(ii" + d.getRel() + ")" + case 0xCE: + result = "SET 1" + sep + "(ii" + d.getRel() + ")" + case 0xD6: + result = "SET 2" + sep + "(ii" + d.getRel() + ")" + case 0xDE: + result = "SET 3" + sep + "(ii" + d.getRel() + ")" + case 0xE6: + result = "SET 4" + sep + "(ii" + d.getRel() + ")" + case 0xEE: + result = "SET 5" + sep + "(ii" + d.getRel() + ")" + case 0xF6: + result = "SET 6" + sep + "(ii" + d.getRel() + ")" + case 0xFE: + result = "SET 7" + sep + "(ii" + d.getRel() + ")" + default: + result = fmt.Sprintf("DB 0x%02X, 0x%02X, 0x%02X", op1, op2, opcode) + } + return result +} + +func (d *Disassembler) opocodeED() string { + opcode := d.getByte() + result := "" + switch opcode { + case 0x40: + result = "IN B" + sep + "(C)" + case 0x41: + result = "OUT (C)" + sep + "B" + case 0x42: + result = "SBC HL" + sep + "BC" + case 0x43: + result = "LD (" + d.getW() + ")" + sep + "BC" + case 0x44, 0x4C, 0x54, 0x5C, 0x64, 0x6C, 0x74, 0x7C: + result = "NEG" + case 0x45, 0x55, 0x5D, 0x65, 0x6D, 0x75, 0x7D: + result = "RETN" + case 0x46, 0x4E, 0x66, 0x6E: + result = "IM 0" + case 0x47: + result = "LD I" + sep + "A" + case 0x48: + result = "IN C" + sep + "(C)" + case 0x49: + result = "OUT (C)" + sep + "C" + case 0x4A: + result = "ADC HL" + sep + "BC" + case 0x4B: + result = "LD BC" + sep + "(" + d.getW() + ")" + case 0x4D: + result = "REТI" + case 0x4F: + result = "LD R" + sep + "A" + case 0x50: + result = "IN D" + sep + "(C)" + case 0x51: + result = "OUT (C)" + sep + "D" + case 0x52: + result = "SBC HL" + sep + "DE" + case 0x53: + result = "LD (nn)" + sep + "DE" + case 0x56, 0x76: + result = "IM 1" + case 0x57: + result = "LD A" + sep + "I" + case 0x58: + result = "IN E" + sep + "(C)" + case 0x59: + result = "OUT (C)" + sep + "E" + case 0x5A: + result = "ADC HL" + sep + "DE" + case 0x5B: + result = "LD DE" + sep + "(" + d.getW() + ")" + case 0x5E, 0x7E: + result = "IM 2" + case 0x5F: + result = "LD A" + sep + "R" + case 0x60: + result = "IN H" + sep + "(C)" + case 0x61: + result = "OUT (C)" + sep + "H" + case 0x62: + result = "SBC HL" + sep + "HL" + case 0x63: + result = "LD (nn)" + sep + "HL" + case 0x67: + result = "RRD" + case 0x68: + result = "IN L" + sep + " (C)" + case 0x69: + result = "OUT (C)" + sep + "L" + case 0x6A: + result = "ADC HL" + sep + " HL" + case 0x6B: + result = "LD HL" + sep + " (nn)" + case 0x6F: + result = "RLD" + case 0x70: + result = "INF" + case 0x71: + result = "OUT (C)" + sep + " 0" + case 0x72: + result = "SBC HL" + sep + "SP" + case 0x73: + result = "LD (nn)" + sep + "SP" + case 0x78: + result = "IN A" + sep + "(C)" + case 0x79: + result = "OUT (C)" + sep + "A" + case 0x7A: + result = "ADC HL" + sep + "SP" + case 0x7B: + result = "LD SP" + sep + "(" + d.getW() + ")" + case 0xA0: + result = "LDI" + case 0xA1: + result = "CPI" + case 0xA2: + result = "INI" + case 0xA3: + result = "OUTI" + case 0xA8: + result = "LDD" + case 0xA9: + result = "CPD" + case 0xAA: + result = "IND" + case 0xAB: + result = "OUTD" + case 0xB0: + result = "LDIR" + case 0xB1: + result = "CPIR" + case 0xB2: + result = "INIR" + case 0xB3: + result = "OTIR" + case 0xB8: + result = "LDDR" + case 0xB9: + result = "CPDR" + case 0xBA: + result = "INDR" + case 0xBB: + result = "OTDR" + default: + result = fmt.Sprintf("DB 0xED, 0x%02X", opcode) + } + return result +} diff --git a/z80/dis/z80disasm_test.go b/z80/dis/z80disasm_test.go new file mode 100644 index 0000000..a904da5 --- /dev/null +++ b/z80/dis/z80disasm_test.go @@ -0,0 +1,91 @@ +package dis + +import "testing" + +var disasm *Disassembler + +type TestComp struct { + memory [65536]byte +} + +func (t *TestComp) M1MemRead(addr uint16) byte { + return t.memory[addr] +} +func (t *TestComp) MemRead(addr uint16) byte { + return t.memory[addr] +} + +func (t *TestComp) MemWrite(addr uint16, val byte) { + t.memory[addr] = val +} + +func (t *TestComp) IORead(port uint16) byte { + return byte(port >> 8) +} + +func (t *TestComp) IOWrite(port uint16, val byte) { + // +} + +var testComp *TestComp + +func init() { + testComp = &TestComp{} + for i := 0; i < 65536; i++ { + testComp.memory[i] = 0x3f + } + disasm = NewDisassembler(testComp) +} + +func setMemory(addr uint16, value []byte) { + for i := 0; i < len(value); i++ { + testComp.memory[addr+uint16(i)] = value[i] + } +} + +var test = []byte{0x31, 0x2c, 0x05, 0x11, 0x0e, 0x01, 0x0e, 0x09, 0xcd, 0x05, 0x00, 0xc3, 0x00, 0x00} + +func Test_LD_SP_nn(t *testing.T) { + expected := " 0100 LD SP, 0x052C" + setMemory(0x100, test) + res := disasm.Disassm(0x100) + if res != expected { + t.Errorf("Error disasm LD SP, nn, result '%s', expected '%s'", res, expected) + } +} + +func Test_LD_DE_nn(t *testing.T) { + expected := " 0103 LD DE, 0x010E" + setMemory(0x100, test) + res := disasm.Disassm(0x103) + if res != expected { + t.Errorf("Error disasm LD DE, nn, result '%s', expected '%s'", res, expected) + } +} + +func Test_LD_C_n(t *testing.T) { + expected := " 0106 LD C, 0x09" + setMemory(0x100, test) + res := disasm.Disassm(0x106) + if res != expected { + t.Errorf("Error disasm LD C, n, result '%s', expected '%s'", res, expected) + } +} + +func Test_CALL_nn(t *testing.T) { + expected := " 0108 CALL 0x0005" + setMemory(0x100, test) + res := disasm.Disassm(0x108) + if res != expected { + t.Errorf("Error disasm CALL nn, result '%s', expected '%s'", res, expected) + } +} + +func Test_JP_nn(t *testing.T) { + expected := " 010B JP 0x0000" + setMemory(0x100, test) + res := disasm.Disassm(0x10b) + if res != expected { + t.Errorf("Error disasm JP nn, result '%s', expected '%s'", res, expected) + } +}