Impl MEM BP, StateSave, BP with Expr, Toolbar

This commit is contained in:
Роман Бойков 2026-03-19 00:20:41 +03:00
parent b200f269a0
commit 415664fc8d
40 changed files with 7845 additions and 419 deletions

120
appWindow.go Normal file
View File

@ -0,0 +1,120 @@
package main
import (
"fmt"
"image/color"
"okemu/okean240"
"okemu/okean240/fdc"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/driver/desktop"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/theme"
"fyne.io/fyne/v2/widget"
)
func mainWindow(computer *okean240.ComputerType) (*fyne.Window, *canvas.Raster, *widget.Label) {
emulatorApp := app.New()
w := emulatorApp.NewWindow("Океан 240.2")
w.Canvas().SetOnTypedKey(
func(key *fyne.KeyEvent) {
computer.PutKey(key)
})
w.Canvas().SetOnTypedRune(
func(key rune) {
computer.PutRune(key)
})
addShortcuts(w.Canvas(), computer)
label := widget.NewLabel(fmt.Sprintf("Screen size: %dx%d", computer.ScreenWidth(), computer.ScreenHeight()))
raster := canvas.NewRasterWithPixels(
func(x, y, w, h int) color.Color {
var xx uint16
if computer.ScreenWidth() == 512 {
xx = uint16(x)
} else {
xx = uint16(x) / 2
}
return computer.GetPixel(xx, uint16(y/2))
})
raster.Resize(fyne.NewSize(512, 512))
raster.SetMinSize(fyne.NewSize(512, 512))
centerRaster := container.NewCenter(raster)
w.Resize(fyne.NewSize(600, 600))
vBox := container.NewVBox(
newToolbar(computer, w),
centerRaster,
label,
)
w.SetContent(vBox)
return &w, raster, label
}
func newToolbar(c *okean240.ComputerType, w fyne.Window) *fyne.Container {
hBox := container.NewHBox()
for d := 0; d < fdc.TotalDrives; d++ {
hBox.Add(widget.NewLabel(string(rune(66+d)) + ":"))
hBox.Add(widget.NewToolbar(
widget.NewToolbarAction(theme.DocumentSaveIcon(), func() {
err := c.SaveFloppy(fdc.FloppyB)
if err != nil {
dialog.ShowError(err, w)
}
}),
//widget.NewToolbarSpacer(),
widget.NewToolbarAction(theme.FolderOpenIcon(), func() {
err := c.SaveFloppy(fdc.FloppyC)
if err != nil {
dialog.ShowError(err, w)
}
}),
))
}
hBox.Add(widget.NewSeparator())
hBox.Add(widget.NewButtonWithIcon("Ctrl+C", theme.LogoutIcon(), func() {
c.PutCtrlKey(0x03)
}))
hBox.Add(widget.NewSeparator())
bNorm := widget.NewButtonWithIcon("", theme.MediaPlayIcon(), func() {
fullSpeed.Store(false)
c.SetCPUFrequency(2_500_000)
//bNorm.Disable()
})
bFast := widget.NewButtonWithIcon("", theme.MediaFastForwardIcon(), func() {
fullSpeed.Store(true)
c.SetCPUFrequency(50_000_000)
bNorm.Enable()
//bFast.Disable()
})
hBox.Add(bNorm)
hBox.Add(bFast)
hBox.Add(layout.NewSpacer())
hBox.Add(widget.NewButtonWithIcon("Reset", theme.MediaReplayIcon(), func() {
needReset = true
//computer.Reset(conf)
}))
return hBox
}
// Add shortcuts for all Ctrl+<Letter>
func addShortcuts(c fyne.Canvas, computer *okean240.ComputerType) {
// Add shortcuts for Ctrl+A to Ctrl+Z
for kName := 'A'; kName <= 'Z'; kName++ {
kk := fyne.KeyName(kName)
sc := &desktop.CustomShortcut{KeyName: kk, Modifier: fyne.KeyModifierControl}
c.AddShortcut(sc, func(shortcut fyne.Shortcut) { computer.PutCtrlKey(byte(kName&0xff) - 0x40) })
}
}

View File

@ -12,14 +12,25 @@ const defaultCPMFile = "rom/CPM_v5.bin"
const DefaultDebufPort = 10000
type OkEmuConfig struct {
LogFile string `yaml:"logFile"`
LogLevel string `yaml:"logLevel"`
MonitorFile string `yaml:"monitorFile"`
CPMFile string `yaml:"cpmFile"`
FloppyB string `yaml:"floppyB"`
FloppyC string `yaml:"floppyC"`
Host string `yaml:"host"`
Port int `yaml:"port"`
LogFile string `yaml:"logFile"`
LogLevel string `yaml:"logLevel"`
MonitorFile string `yaml:"monitorFile"`
CPMFile string `yaml:"cpmFile"`
FDC FDCConfig `yaml:"fdc"`
Debugger DebuggerConfig `yaml:"debugger"`
}
type FDCConfig struct {
AutoLoadB bool `yaml:"autoLoadB"`
AutoLoadC bool `yaml:"autoLoadC"`
FloppyB string `yaml:"floppyB"`
FloppyC string `yaml:"floppyC"`
}
type DebuggerConfig struct {
Enabled bool `yaml:"enabled"`
Host string `yaml:"host"`
Port int `yaml:"port"`
}
var config *OkEmuConfig
@ -62,8 +73,8 @@ func LoadConfig() {
}
func checkConfig(conf *OkEmuConfig) {
if conf.Host == "" {
conf.Host = "localhost"
if conf.Debugger.Host == "" {
conf.Debugger.Host = "localhost"
}
}
@ -80,8 +91,8 @@ func setDefaultConf(conf *OkEmuConfig) {
if conf.CPMFile == "" {
conf.CPMFile = defaultCPMFile
}
if conf.Port < 80 || conf.Port > 65535 {
log.Infof("Port %d incorrect, using default: %d", conf.Port, DefaultDebufPort)
conf.Port = DefaultDebufPort
if conf.Debugger.Port < 80 || conf.Debugger.Port > 65535 {
log.Infof("Port %d incorrect, using default: %d", conf.Debugger.Port, DefaultDebufPort)
conf.Debugger.Port = DefaultDebufPort
}
}

View File

@ -0,0 +1,246 @@
package breakpoint
import (
"context"
"fmt"
"okemu/gval"
"regexp"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const MaxBreakpoints = 256
const (
BPTypeSimplePC = iota // Simple PC=nn breakpoint
BPTypeSimpleSP // Simple SP>=nn breakpoint
BPTypeExpression // Complex expression breakpoint
)
type Breakpoint struct {
addr uint16
cond string
eval gval.Evaluable
bpType int
pass uint16
passCount uint16
enabled bool
}
var andMatch = regexp.MustCompile(`\s+AND\s+`)
var orMatch = regexp.MustCompile(`\s+OR\s+`)
// var xorMatch = regexp.MustCompile(`\s+XOR\s+`)
var hexHMatch = regexp.MustCompile(`[[:xdigit:]]+H`)
var eqMatch = regexp.MustCompile(`[^=><]=[^=]`)
func patchExpression(expr string) string {
ex := strings.ToUpper(expr)
ex = andMatch.ReplaceAllString(ex, " && ")
ex = orMatch.ReplaceAllString(ex, " || ")
// ex = xorMatch.ReplaceAllString(ex, " ^ ")
for {
pos := hexHMatch.FindStringIndex(ex)
if pos != nil && len(pos) == 2 {
hex := "0x" + ex[pos[0]:pos[1]-1]
ex = ex[:pos[0]] + hex + ex[pos[1]:]
} else {
break
}
}
for {
pos := eqMatch.FindStringIndex(ex)
if pos != nil && len(pos) == 2 {
ex = ex[:pos[0]+1] + "==" + ex[pos[1]-1:]
} else {
break
}
}
return ex
}
var pcMatch = regexp.MustCompile(`^PC=[[:xdigit:]]+h$`)
var spMatch = regexp.MustCompile(`^SP>=[[:xdigit:]]+$`)
func getSecondUint16(param string, sep string) (uint16, error) {
p := strings.Split(param, sep)
v := p[1]
base := 0
if strings.HasSuffix(v, "h") || strings.HasSuffix(v, "H") {
v = strings.TrimSuffix(v, "H")
v = strings.TrimSuffix(v, "h")
base = 16
}
a, e := strconv.ParseUint(v, base, 16)
if e != nil {
return 0, e
}
return uint16(a), nil
}
func NewBreakpoint(expr string) (*Breakpoint, error) {
bp := Breakpoint{
addr: 0,
enabled: false,
passCount: 0,
pass: 0,
bpType: BPTypeSimplePC,
}
// Check if BP is simple PC=addr
expr = strings.TrimSpace(expr)
bp.cond = expr
pcMatched := pcMatch.MatchString(expr)
spMatched := spMatch.MatchString(expr)
if pcMatched {
// PC=xxxxh
bp.bpType = BPTypeSimplePC
v, e := getSecondUint16(expr, "=")
if e != nil {
return nil, e
}
bp.addr = v
} else if spMatched {
// SP>=xxxx
bp.bpType = BPTypeSimpleSP
v, e := getSecondUint16(expr, "=")
if e != nil {
return nil, e
}
bp.addr = v
} else {
// complex expression
bp.bpType = BPTypeExpression
ex := patchExpression(expr)
log.Debugf("Original Expression: '%s'", expr)
log.Debugf(" Patched Expression: '%s'", ex)
err := bp.SetExpression(ex)
if err != nil {
return nil, err
}
}
return &bp, nil
}
func (b *Breakpoint) Enabled() bool {
return b.enabled
}
func (b *Breakpoint) SetEnabled(enabled bool) {
b.enabled = enabled
}
func (b *Breakpoint) PassCount() uint16 {
return b.passCount
}
func (b *Breakpoint) SetPassCount(passCount uint16) {
b.passCount = passCount
}
func (b *Breakpoint) Pass() uint16 {
return b.pass
}
func (b *Breakpoint) SetPass(pass uint16) {
b.pass = pass
}
func (b *Breakpoint) IncPass() {
b.pass++
}
func (b *Breakpoint) Addr() uint16 {
return b.addr
}
func (b *Breakpoint) SetAddr(addr uint16) {
b.addr = addr
}
func (b *Breakpoint) Type() int {
return b.bpType
}
func (b *Breakpoint) SetType(bpType int) {
b.bpType = bpType
}
func getUint16(name string, ctx map[string]interface{}) uint16 {
if v, ok := ctx[name]; ok {
if v == nil {
return 0
}
// most frequent case
if v, ok := v.(uint16); ok {
return v
}
// for less frequent cases
switch value := v.(type) {
case int:
return uint16(value)
case int8:
return uint16(value)
case int16:
return uint16(value)
case int32:
return uint16(value)
case int64:
return uint16(value)
case uint:
return uint16(value)
case uint8:
return uint16(value)
case uint32:
return uint16(value)
case uint64:
return uint16(value)
default:
log.Errorf("Unknown type %v for variable %s", value, name)
return 0
}
} else {
log.Errorf("Variable %s not found in context!", name)
}
return 0
}
func (b *Breakpoint) Hit(ctx map[string]interface{}) bool {
if !b.enabled {
return false
}
if b.bpType == BPTypeSimplePC {
pc := getUint16("PC", ctx)
if pc == b.addr {
log.Debugf("Breakpoint Hit PC=%04X", b.addr)
}
return pc == b.addr
} else if b.bpType == BPTypeSimpleSP {
sp := getUint16("SP", ctx)
if sp >= b.addr {
log.Debugf("Breakpoint Hit SP>=%04X", b.addr)
}
return sp >= b.addr
}
value, err := b.eval.EvalBool(context.Background(), ctx)
if err != nil {
fmt.Println(err)
}
return value
}
var language gval.Language
func init() {
language = gval.NewLanguage(gval.Base(), gval.Arithmetic(), gval.Bitmask(), gval.PropositionalLogic())
}
func (b *Breakpoint) SetExpression(expression string) error {
var err error
b.eval, err = language.NewEvaluable(expression)
if err != nil {
log.Error("Illegal expression", err)
return err
}
b.bpType = BPTypeExpression
return nil
}
func (b *Breakpoint) Expression() string {
return b.cond
}

View File

@ -0,0 +1,76 @@
package breakpoint
import (
"regexp"
"testing"
)
const expr1 = "PC=00100h && SP>=256"
const expr2 = "PC=00115h"
const expr3 = "SP>=1332"
var ctx = map[string]interface{}{
"PC": 0x100,
"A": 0x55,
"SP": 0x200,
}
const exprRep = "PC=00115h and (B=5 or BC = 5)"
const exprDst = "PC==0x00115 && (B==5 || BC == 5)"
func Test_PatchExpression(t *testing.T) {
ex := patchExpression(exprRep)
if ex != exprDst {
t.Errorf("Patched expression does not match\n got: %s\nexpected: %s", ex, exprDst)
}
}
func Test_ComplexExpr(t *testing.T) {
b, e := NewBreakpoint(expr1)
//e := b.SetExpression(exp1)
if e != nil {
t.Error(e)
} else if b != nil {
b.enabled = true
if !b.Hit(ctx) {
t.Errorf("Breakpoint not hit")
}
}
}
const expSimplePC = "PC=00119h"
func Test_BPSetPC(t *testing.T) {
b, e := NewBreakpoint(expSimplePC)
if e != nil {
t.Error(e)
} else if b != nil {
if b.bpType != BPTypeSimplePC {
t.Errorf("Breakpoint type does not match BPTypeSimplePC")
}
b.enabled = true
if b.Hit(ctx) {
t.Errorf("Breakpoint hit but will not!")
}
}
}
func Test_MatchSP(t *testing.T) {
pcMatch := regexp.MustCompile(`SP>=[[:xdigit:]]+$`)
matched := pcMatch.MatchString(expr3)
if !matched {
t.Errorf("SP>=XXXXh not matched")
}
}
func Test_GetCtx(t *testing.T) {
pc := getUint16("PC", ctx)
if pc != 0x100 {
t.Errorf("PC value not found in context")
}
}

225
debug/debuger.go Normal file
View File

@ -0,0 +1,225 @@
package debug
import (
"okemu/debug/breakpoint"
"okemu/z80"
"okemu/z80/dis"
log "github.com/sirupsen/logrus"
)
const BPMemAccess = 65535
type Debugger struct {
stepMode bool
doStep bool
runMode bool
runInst uint64
breakpointsEnabled bool
breakpoints map[uint16]*breakpoint.Breakpoint
cpuFrequency uint32
disassembler *dis.Disassembler
cpuHistoryEnabled bool
cpuHistoryStarted bool
cpuHistoryMaxSize int
cpuHistory []*z80.CPU
memBreakpoints [65536]byte
}
func NewDebugger() *Debugger {
d := Debugger{
stepMode: false,
doStep: false,
runMode: false,
runInst: 0,
breakpointsEnabled: false,
breakpoints: map[uint16]*breakpoint.Breakpoint{},
cpuHistoryEnabled: false,
cpuHistoryStarted: false,
cpuHistoryMaxSize: 0,
cpuHistory: []*z80.CPU{},
}
return &d
}
func (d *Debugger) SetStepMode(step bool) {
d.SetRunMode(false)
d.stepMode = step
}
func (d *Debugger) SetRunMode(run bool) {
if run {
d.runInst = 0
}
d.runMode = run
}
func (d *Debugger) RunMode() bool {
return d.runMode
}
func (d *Debugger) DoStep() bool {
if d.doStep {
d.doStep = false
return true
}
return false
}
func (d *Debugger) SetCpuHistoryEnabled(enable bool) {
d.cpuHistoryEnabled = enable
}
func (d *Debugger) SetCpuHistoryMaxSize(size int) {
if size < 0 || size > 1_000_000 {
log.Error("CPU history max size must be positive and up to 1M")
} else {
d.cpuHistoryMaxSize = size
}
}
func (d *Debugger) CpuHistoryClear() {
d.cpuHistory = make([]*z80.CPU, 0)
}
func (d *Debugger) CpuHistorySize() int {
return len(d.cpuHistory)
}
func (d *Debugger) CpuHistory(index int) *z80.CPU {
if index >= 0 && index < len(d.cpuHistory) {
return d.cpuHistory[index]
}
if len(d.cpuHistory) > 0 {
log.Warnf("CPU history index %d out of range [0:%d]", index, len(d.cpuHistory)-1)
} else {
log.Warn("CPU history is empty")
}
return nil
}
func (d *Debugger) SetCpuHistoryStarted(started bool) {
d.cpuHistoryStarted = started
}
func (d *Debugger) SaveHistory(state *z80.CPU) {
if d.cpuHistoryEnabled && d.cpuHistoryMaxSize > 0 && d.cpuHistoryStarted {
d.cpuHistory = append([]*z80.CPU{state}, d.cpuHistory...)
if len(d.cpuHistory) > d.cpuHistoryMaxSize {
d.cpuHistory = d.cpuHistory[0 : d.cpuHistoryMaxSize-1]
}
}
}
func (d *Debugger) CheckBreakpoints(ctx map[string]interface{}) (bool, uint16) {
if d.breakpointsEnabled && d.runMode {
for n, bp := range d.breakpoints {
if bp != nil && bp.Hit(ctx) {
// breakpoint hit
if bp.Pass() >= bp.PassCount() {
bp.SetPass(0)
d.runMode = false
return true, n
}
// increment breakpoint pass count
bp.IncPass()
}
}
}
return false, 0
}
func (d *Debugger) SetBreakpointsEnabled(enabled bool) {
d.breakpointsEnabled = enabled
}
func (d *Debugger) BreakpointsEnabled() bool {
return d.breakpointsEnabled
}
// SetBreakpoint Create new breakpoint with specified number
func (d *Debugger) SetBreakpoint(number uint16, exp string) error {
var err error
bp, err := breakpoint.NewBreakpoint(exp)
if err == nil && bp != nil {
d.breakpoints[number] = bp
}
return err
}
func (d *Debugger) SetBreakpointPassCount(number uint16, count uint16) {
bp, ok := d.breakpoints[number]
if ok && bp != nil {
bp.SetPass(0)
bp.SetPassCount(count)
}
}
func (d *Debugger) SetBreakpointEnabled(number uint16, enabled bool) {
bp, ok := d.breakpoints[number]
if ok && bp != nil {
bp.SetEnabled(enabled)
}
}
func (d *Debugger) BreakpointEnabled(number uint16) bool {
bp, ok := d.breakpoints[number]
if ok && bp != nil {
return bp.Enabled()
}
return false
}
func (d *Debugger) ClearMemBreakpoints() {
for c := 0; c < 65536; c++ {
d.memBreakpoints[c] = 0
}
}
func (d *Debugger) StepMode() bool {
return d.stepMode
}
func (d *Debugger) SetDoStep(on bool) {
d.doStep = on
}
// BPExpression Return requested breakpoint
func (d *Debugger) BPExpression(number uint16) string {
bp, ok := d.breakpoints[number]
if ok && bp != nil {
return bp.Expression()
}
return ""
}
// RunInst return and increment count of instructions executed
func (d *Debugger) RunInst() uint64 {
v := d.runInst
d.runInst++
return v
}
func (d *Debugger) SetMemBreakpoint(address uint16, typ byte, size uint16) {
var offset uint16
for offset = address; offset < address+size; offset++ {
d.memBreakpoints[offset] = typ
}
}
func (d *Debugger) CheckMemBreakpoints(accessMap *map[uint16]byte) (bool, uint16, byte) {
if !d.breakpointsEnabled {
return false, 0, 0
}
for addr, typ := range *accessMap {
bp := d.memBreakpoints[addr]
if bp == 0 {
return false, addr, 0
}
if (bp == 3) || bp == typ {
d.SetRunMode(false)
return true, addr, typ
}
}
return false, 0, 0
}

149
debug/evaluate.go Normal file
View File

@ -0,0 +1,149 @@
package debug
import (
"errors"
"strconv"
"strings"
)
// Operators with their priority
var operators = map[string]int{
"(": 12, ")": 12,
"*": 11, "/": 11, "%": 11,
"+": 10, "-": 10,
"<<": 9, ">>": 9,
"<": 8, "<=": 8, ">": 8, ">=": 8,
"=": 7, "!=": 7,
"&": 6,
"^": 5,
"|": 4,
"&&": 3,
"||": 2,
}
var variables = map[string]bool{
"A": true, "B": true, "C": true, "D": true, "E": true, "F": true, "H": true, "L": true, "I": true,
"R": true, "SF": true, "NF": true, "PF": true, "VF": true, "XF": true, "YF": true, "ZF": true,
"A'": true, "B'": true, "C'": true, "D'": true, "E'": true, "F'": true, "H'": true, "L'": true,
"AF": true, "BC": true, "DE": true, "HL": true, "IX": true, "IY": true, "PC": true, "SP": true,
}
const (
OTUnknown = iota
OTValue
OTVariable
OTOperation
)
type Token struct {
name string
val uint16
ot int
}
type Expression struct {
infixExp string
inStack []Token
outStack []Token
}
func NewExpression(infixExp string) *Expression {
return &Expression{infixExp, make([]Token, 0), make([]Token, 0)}
}
func (e *Expression) Parse() error {
e.infixExp = strings.ToUpper(strings.TrimSpace(e.infixExp))
if e.infixExp == "" {
return errors.New("no Expression")
}
ptr := 0
for ptr < len(e.infixExp) {
token, err := getNextToken(e.infixExp[ptr:])
if err != nil {
return err
}
err = validate(token)
if err != nil {
return err
}
err = e.parseToken(token)
if err != nil {
return err
}
ptr += len(token.name)
}
return nil
}
func (e *Expression) parseToken(token Token) error {
return nil
}
func validate(token Token) error {
switch token.ot {
case OTValue:
v, err := strconv.ParseUint(token.name, 0, 16)
if err != nil {
return err
}
token.val = uint16(v)
case OTVariable:
if !variables[token.name] {
return errors.New("unknown variable")
}
case OTOperation:
v, ok := operators[token.name]
if !ok {
return errors.New("unknown operation")
}
token.val = uint16(v)
default:
return errors.New("unknown token")
}
return nil
}
const operations = "*/%+_<=>!&^|"
func getNextToken(str string) (Token, error) {
ptr := 0
exp := ""
ot := OTUnknown
for ptr < len(str) {
ch := str[ptr]
if ch == ' ' {
if ot == OTUnknown {
ptr++
continue
} else {
// end of token
return Token{name: exp, ot: ot}, nil
}
}
if (ch == 'X' || ch == 'O' || ch == 'B' || ch == 'H') && ot != OTValue {
exp += string(ch)
ptr++
continue
}
if ch >= '0' && ch <= '9' {
if len(exp) == 0 {
ot = OTValue
}
exp += string(ch)
ptr++
continue
}
if strings.Contains(operations, string(ch)) {
if len(exp) == 0 {
ot = OTOperation
}
exp += string(ch)
ptr++
continue
}
return Token{name: exp, ot: ot}, errors.New("invalid token")
}
return Token{name: exp, ot: ot}, nil
}

View File

@ -0,0 +1,14 @@
package listener
const welcomeMessage = "Welcome to Ocean-240.2 remote command protocol (ZRCP partial implementation)\nWrite help for available commands\n\ncommand> "
const emptyResponse = "\ncommand> "
const aboutResponse = "ZEsarUX remote command protocol"
const getVersionResponse = "12.1"
const getRegistersResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x F=%s F'=%s MEMPTR=%04x IM0 IFF%s VPS: 0 MMU=00000000000000000000000000000000"
const getStateResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x IM0 IFF%s (PC)=%s (SP)=%s MMU=00000000000000000000000000000000"
const inCpuStepResponse = "\ncommand@cpu-step> "
const getMachineResponse = "64K RAM, no ZX\n"
const respErrorLoading = "ERROR loading file"
const quitResponse = "Sayonara baby\n"
const runUntilBPMessage = "Running until a breakpoint, key press or data sent, menu opening or other event\n"

View File

@ -1,4 +1,4 @@
package debuger
package listener
import (
"bufio"
@ -6,7 +6,11 @@ import (
"io"
"net"
"okemu/config"
"okemu/debug"
"okemu/debug/breakpoint"
"okemu/okean240"
"okemu/z80"
"okemu/z80/dis"
"os"
"strings"
//"okemu/logger"
@ -15,16 +19,6 @@ import (
log "github.com/sirupsen/logrus"
)
const welcomeMessage = "Welcome to ZEsarUX remote command protocol (ZRCP)\nWrite help for available commands\n\ncommand> "
const emptyResponse = "\ncommand> "
const aboutResponse = "ZEsarUX remote command protocol"
const getVersionResponse = "12.1"
const getRegistersResponse = "PC=%04x SP=%04x AF=%04x BC=%04x HL=%04x DE=%04x IX=%04x IY=%04x AF'=%04x BC'=%04x HL'=%04x DE'=%04x I=%02x R=%02x F=%s F'=%s MEMPTR=%04x IM0 IFF%s VPS: 0 MMU=00000000000000000000000000000000"
const inCpuStepResponse = "\ncommand@cpu-step> "
const getMachineResponse = "64K RAM, no ZX\n"
const respErrorLoading = "ERROR loading file"
const quitResponse = "Sayonara baby\n"
// Receive messages, split to strings and parse
func handleConnection(c net.Conn) {
reader := bufio.NewReader(c)
@ -40,6 +34,7 @@ func handleConnection(c net.Conn) {
break
} else {
log.Errorf("TCP error: %v", err)
debugger.SetStepMode(false)
return
}
}
@ -50,8 +45,8 @@ func handleConnection(c net.Conn) {
}
//byteBuffer.WriteByte(b)
}
debugger.SetStepMode(false)
activeWriter = nil
//log.Trace("TCP Connection closed")
err := c.Close()
if err != nil {
log.Warnf("Can not close socket: %v", err)
@ -69,7 +64,7 @@ func writeWelcomeMessage(writer *bufio.Writer) bool {
func writeResponseMessage(writer *bufio.Writer, message string) bool {
prompt := emptyResponse
if computer.IsStepMode() {
if debugger.StepMode() {
prompt = inCpuStepResponse
}
@ -86,13 +81,33 @@ func writeResponseMessage(writer *bufio.Writer, message string) bool {
return true
}
func writeMessage(writer *bufio.Writer, message string) bool {
_, err := writer.WriteString(message)
if err != nil {
log.Errorf("TCP error: %v", err)
return false
}
err = writer.Flush()
if err != nil {
log.Errorf("TCP error: %v", err)
return false
}
return true
}
// var
var debugger *debug.Debugger
var disassembler *dis.Disassembler
var computer *okean240.ComputerType
// SetupTcpHandler Setup TCP listener, handle connections
func SetupTcpHandler(config *config.OkEmuConfig, comp *okean240.ComputerType) {
port := config.Host + ":" + strconv.Itoa(config.Port)
func SetupTcpHandler(config *config.OkEmuConfig, debug *debug.Debugger, disasm *dis.Disassembler, comp *okean240.ComputerType) {
port := config.Debugger.Host + ":" + strconv.Itoa(config.Debugger.Port)
debugger = debug
disassembler = disasm
computer = comp
log.Infof("Serve TCP connections on %s", port)
log.Infof("Ready for debugger connections on %s", port)
l, err := net.Listen("tcp4", port)
if err != nil {
@ -138,20 +153,19 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
switch cmd {
case "cpu-step":
computer.Do()
writeResponseMessage(writer, " "+fmt.Sprintf("%04X", computer.GetCPUState().PC))
debugger.SetDoStep(true) // computer.Do()
text := disassembler.Disassm(computer.GetCPUState().PC)
writeResponseMessage(writer, registersResponse(computer.GetCPUState())+" TSTATES: "+strconv.Itoa(int(computer.TStatesPartial()))+"\n"+text)
case "run":
_, e := writer.WriteString("Running until a breakpoint, key press or data sent, menu opening or other event\n")
if e != nil {
log.Warnf("Error writing to buffer: %v", e)
}
e = writer.Flush()
if e != nil {
log.Warnf("Error flushing the buffer: %v", e)
}
computer.SetRunMode(true)
writeMessage(writer, runUntilBPMessage)
debugger.SetRunMode(true)
case "disassemble":
writeResponseMessage(writer, disassemble(params))
case "get-tstates-partial":
writeResponseMessage(writer, strconv.FormatUint(computer.Cycles(), 10))
writeResponseMessage(writer, strconv.FormatUint(computer.TStatesPartial(), 10))
case "reset-tstates-partial":
computer.ResetTStatesPartial()
writeResponseMessage(writer, "")
case "close-all-menus":
writeResponseMessage(writer, "")
case "about":
@ -159,17 +173,17 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
case "get-version":
writeResponseMessage(writer, getVersionResponse)
case "get-registers":
writeResponseMessage(writer, registersResponse())
writeResponseMessage(writer, registersResponse(computer.GetCPUState()))
case "set-register":
writeResponseMessage(writer, setRegister(params))
case "hard-reset-cpu":
computer.Reset()
writeResponseMessage(writer, "")
case "enter-cpu-step":
computer.SetStepMode(true)
debugger.SetStepMode(true)
writeResponseMessage(writer, "")
case "exit-cpu-step":
computer.SetStepMode(false)
debugger.SetStepMode(false)
writeResponseMessage(writer, "")
case "set-debug-settings":
log.Debugf("Set debug settings to %s", params)
@ -177,13 +191,15 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
case "get-current-machine":
writeResponseMessage(writer, getMachineResponse)
case "clear-membreakpoints":
computer.ClearMemBreakpoints()
debugger.ClearMemBreakpoints()
writeResponseMessage(writer, "")
case "set-membreakpoint": // addr type size
writeResponseMessage(writer, SetMemBreakpoint(params))
case "enable-breakpoints":
computer.SetBreakpointsEnabled(true)
debugger.SetBreakpointsEnabled(true)
writeResponseMessage(writer, "")
case "disable-breakpoints":
computer.SetBreakpointsEnabled(false)
debugger.SetBreakpointsEnabled(false)
writeResponseMessage(writer, "")
case "enable-breakpoint":
writeResponseMessage(writer, setBreakpointState(params, true))
@ -194,6 +210,9 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
case "set-breakpoint":
// 1 PC=0010Bh
writeResponseMessage(writer, setBreakpoint(params))
case "set-breakpointpasscount":
setBreakpointPassCount(params)
writeResponseMessage(writer, "")
case "cpu-code-coverage":
//"enabled no"
writeResponseMessage(writer, "")
@ -204,7 +223,8 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
// "started yes"
// "ignrephalt yes"
// "ignrepldxr yes"
writeResponseMessage(writer, "")
writeResponseMessage(writer, doCpuHistory(params))
case "extended-stack":
// "enabled no"
// "enabled yes"
@ -219,6 +239,11 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
writeResponseMessage(writer, readMemory(params))
case "quit":
quit = true
case "snapshot-save":
writeResponseMessage(writer, snapshotSave(params))
case "set-breakpointaction":
// now do nothing
writeResponseMessage(writer, "")
default:
log.Debugf("Unhandled Command: %s", str)
writeResponseMessage(writer, "")
@ -226,6 +251,96 @@ func HandleCommand(str string, writer *bufio.Writer) bool {
return !quit
}
func convertToUint16(s string) (uint16, error) {
v := strings.TrimSpace(strings.ToUpper(s))
base := 0
if strings.HasSuffix(v, "h") || strings.HasSuffix(v, "H") {
v = strings.TrimSuffix(v, "H")
v = strings.TrimSuffix(v, "h")
base = 16
}
a, e := strconv.ParseUint(v, base, 16)
return uint16(a), e
}
func SetMemBreakpoint(param string) string {
param = strings.TrimSpace(param)
params := strings.Split(param, " ")
if len(params) < 1 {
return "error, not enough parameters"
}
address, err := convertToUint16(params[0])
if err != nil {
return "error, illegal address: '" + params[0] + "'"
}
t := uint16(3)
// if has type
if len(params) > 1 {
t, err = convertToUint16(params[1])
if err != nil || t > 3 {
return "error, illegal access type: '" + params[1] + "'"
}
}
s := uint16(1)
if len(params) > 2 {
s, err = convertToUint16(params[2])
if err != nil {
return "error, illegal memory size: '" + params[2] + "'"
}
}
if debugger != nil {
debugger.SetMemBreakpoint(address, byte(t), s)
}
return ""
}
func doCpuHistory(param string) string {
param = strings.TrimSpace(param)
params := strings.Split(param, " ")
if len(params) == 0 {
return "error"
}
cmd := params[0]
switch cmd {
case "enabled":
if len(params) != 2 {
return "error"
}
debugger.SetCpuHistoryEnabled(params[1] == "yes")
case "clear":
debugger.CpuHistoryClear()
case "started":
if len(params) != 2 {
return "error"
}
debugger.SetCpuHistoryStarted(params[1] == "yes")
case "set-max-size":
if len(params) != 2 {
return "error"
}
size, err := strconv.Atoi(params[1])
if err != nil {
return "error"
}
debugger.SetCpuHistoryMaxSize(size)
case "get":
if len(params) != 2 {
return "error"
}
index, err := strconv.Atoi(params[1])
if err != nil {
return "error"
}
history := debugger.CpuHistory(index)
if history != nil {
return stateResponse(history)
}
return "ERROR: index out of range"
}
return ""
}
func loadBinary(param string) string {
params := strings.Split(param, " ")
if len(params) < 2 {
@ -246,9 +361,11 @@ func loadBinary(param string) string {
return respErrorLoading
}
if len(params) > 2 {
length, e = strconv.Atoi(params[1])
l, e := strconv.ParseInt(params[2], 0, 32)
if e != nil {
length = 0
} else {
length = int(l)
}
}
data, err := os.ReadFile(fn)
@ -256,9 +373,10 @@ func loadBinary(param string) string {
log.Errorf("Error reading file: %v", err)
return respErrorLoading
}
if length != 0 && len(data) < length {
log.Errorf("File too short. Expected %d bytes, got %d", len(data), length)
return respErrorLoading
if length != 0 && len(data) != length {
log.Warnf("File size does not match the specified length. Expected %d bytes, got %d.", length, len(data))
//return respErrorLoading
length = len(data)
}
if length == 0 {
length = len(data)
@ -289,8 +407,8 @@ func iifStr(iif1, iif2 bool) string {
// registersResponse Build string
// PC=%4x SP=%4x AF=%4x BC=%4x HL=%4x DE=%4x IX=%4x IY=%4x AF'=%4x BC'=%4x HL'=%4x DE'=%4x I=%2x
// R=%2x F=%s F'=%s MEMPTR=%4x IM0 IFF-- VPS: 0 MMU=00000000000000000000000000000000
func registersResponse() string {
state := computer.GetCPUState()
func registersResponse(state *z80.CPU) string {
//state := computer.GetCPUState()
resp := fmt.Sprintf(getRegistersResponse,
state.PC,
state.SP,
@ -311,7 +429,43 @@ func registersResponse() string {
state.MemPtr,
iifStr(state.Iff1, state.Iff2),
)
log.Debug(resp)
log.Trace(resp)
return resp
}
func getNBytes(addr uint16, n uint16) string {
res := ""
for i := uint16(0); i < n; i++ {
b := computer.MemRead(addr + i)
res += fmt.Sprintf("%02X", b)
}
return res
}
// stateResponse build string, represent history state
// PC=003a SP=ff46 AF=005c BC=174b HL=107f DE=0006 IX=ffff IY=5c3a AF'=0044 BC'=ffff HL'=ffff DE'=5cb9 I=3f R=78
// IM0 IFF-- (PC)=2a785c23 (SP)=107f MMU=00000000000000000000000000000000
func stateResponse(state *z80.CPU) string {
resp := fmt.Sprintf(getStateResponse,
state.PC,
state.SP,
toW(state.A, state.Flags.GetFlags()),
toW(state.B, state.C),
toW(state.H, state.L),
toW(state.D, state.E),
state.IX,
state.IY,
toW(state.AAlt, state.FlagsAlt.GetFlags()),
toW(state.BAlt, state.CAlt),
toW(state.HAlt, state.LAlt),
toW(state.DAlt, state.EAlt),
state.I,
state.R,
iifStr(state.Iff1, state.Iff2),
getNBytes(state.PC, 4),
getNBytes(state.SP, 2),
)
log.Trace(resp)
return resp
}
@ -320,12 +474,12 @@ func setRegister(param string) string {
params := strings.Split(param, "=")
if len(params) != 2 {
log.Errorf("Invalid set register parameter: %s", param)
return registersResponse()
return "error"
}
val, e := strconv.Atoi(params[1])
if e != nil {
log.Errorf("Invalid set register parameter value: %s", params[1])
return registersResponse()
return "error"
}
switch params[0] {
case "SP":
@ -359,14 +513,14 @@ func setRegister(param string) string {
log.Errorf("Unsupported set register parameter: %s", param)
}
computer.SetCPUState(state)
return registersResponse()
return registersResponse(computer.GetCPUState())
}
func readMemory(param string) string {
params := strings.Split(param, " ")
if len(params) != 2 {
log.Errorf("Invalid read memory parameter: %s", param)
return registersResponse()
return "error" //registersResponse(computer.GetCPUState())
}
offset, e := strconv.Atoi(params[0])
if e != nil {
@ -381,7 +535,6 @@ func readMemory(param string) string {
for i := 0; i < size; i++ {
resp += fmt.Sprintf("%02X", computer.MemRead(uint16(offset)+uint16(i)))
}
log.Tracef("ReadMemory[%d,%d]:\n%s", offset, size, resp)
return resp
}
@ -411,7 +564,7 @@ func getExtendedStack(param string) string {
for i := sp; i > spEnd; i -= 2 {
resp += fmt.Sprintf("%04XH default\n", computer.MemRead(i))
}
log.Debugf("Stack[%d,%d]:\n%s", sp, size, resp)
//log.Debugf("Stack[%d,%d]:\n%s", sp, size, resp)
return resp
}
@ -421,49 +574,94 @@ func setBreakpointState(param string, enable bool) string {
log.Errorf("Invalid breakpoint parameter: %s", param)
return ""
}
if enable && !computer.IsBreakpointsEnabled() {
if enable && !debugger.BreakpointsEnabled() {
return "Error. You must enable breakpoints first"
}
computer.SetBreakpointEnabled(uint16(no), enable)
debugger.SetBreakpointEnabled(uint16(no), enable)
return ""
}
func setBreakpoint(param string) string {
// 1 PC=0010Bh
params := strings.Split(param, " ")
if len(params) != 2 {
if len(params) < 2 {
log.Errorf("Invalid set breakpoint parameters: %s", param)
return ""
return "Error, invalid parameters"
}
no, e := strconv.Atoi(params[0])
if e != nil || no > okean240.MaxBreakpoints || no < 1 {
no, e := strconv.ParseUint(params[0], 0, 16)
if e != nil || no > breakpoint.MaxBreakpoints || no < 1 {
log.Errorf("Invalid breakpoint number: %s", params[0])
return ""
return "Error, invalid breakpoint number"
}
regv := strings.Split(params[1], "=")
if len(regv) != 2 {
log.Errorf("Invalid breakpoint parameter: %s", params[1])
return ""
}
addr, e := strconv.ParseUint(strings.TrimSuffix(regv[1], "h"), 16, 32)
if e != nil || addr < 0 || addr >= 65535 {
log.Errorf("Invalid breakpoint address: %s", regv[1])
return ""
}
if regv[0] == "PC" {
computer.SetBreakpoint(uint16(no), uint16(addr))
} else {
log.Errorf("Unsupported BP: %s", params[1])
e = debugger.SetBreakpoint(uint16(no), param[len(params[0]):])
if e != nil {
return "Error: " + e.Error()
}
return ""
}
func BreakpointHit(no uint16) {
func typToString(typ uint8) string {
switch typ {
case 0:
return "D"
case 1:
return "R"
case 2:
return "W"
case 3:
return "R/W"
default:
return "x"
}
}
func BreakpointHit(number uint16, typ byte) {
if activeWriter != nil {
pc := computer.GetCPUState().PC
rep := fmt.Sprintf("Breakpoint fired: PC=%XH\n %04X NOP", pc, pc)
res := disassembler.Disassm(pc)
msg := ""
if typ == 0 {
msg = debugger.BPExpression(number)
} else {
msg = fmt.Sprintf("MEM[%04X] %s", number, typToString(typ))
}
rep := fmt.Sprintf("Breakpoint fired: %s\n%s", msg, res)
log.Debug(rep)
writeResponseMessage(activeWriter, rep)
}
}
func setBreakpointPassCount(param string) {
params := strings.Split(param, " ")
if len(params) != 2 {
log.Errorf("Set breakpoint passCount failed, expected 2 params, got %d", len(params))
}
bpNo, err := strconv.Atoi(params[0])
if err != nil || bpNo < 0 || bpNo > breakpoint.MaxBreakpoints {
log.Errorf("Invalid BP no.: %v", err)
}
passCount, err := strconv.Atoi(params[1])
if err != nil || passCount < 0 || passCount > 65535 {
log.Errorf("Invalid BP passCount: %v", err)
}
debugger.SetBreakpointPassCount(uint16(bpNo), uint16(passCount))
}
func disassemble(param string) string {
addr, e := strconv.ParseUint(param, 0, 16)
if e != nil {
log.Errorf("Invalid disassemble address: %s", param)
}
res := disassembler.Disassm(uint16(addr))
log.Debug(res)
return res
}
func snapshotSave(params string) string {
e := computer.SaveSnapshot(strings.TrimSpace(params))
if e != nil {
return fmt.Sprintf("Error saving snapshot: %s", e)
}
return ""
}

View File

@ -0,0 +1,17 @@
package listener
import (
"testing"
)
const exp1 = "CS=0x0100 & SP>=256"
const memBpSet = "11ch 3 1"
func Test_SetMemPB(t *testing.T) {
resp := SetMemBreakpoint(memBpSet)
if resp != "" {
t.Errorf("SetMemBreakpoint() returned %s", resp)
}
}

2
go.mod
View File

@ -4,6 +4,7 @@ go 1.25
require (
fyne.io/fyne/v2 v2.7.3
github.com/PaesslerAG/gval v1.2.4
github.com/howeyc/crc16 v0.0.0-20171223171357-2b2a61e366a6
github.com/sirupsen/logrus v1.9.4
gopkg.in/yaml.v3 v3.0.1
@ -33,6 +34,7 @@ require (
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/testify v1.11.1 // indirect

6
go.sum
View File

@ -4,6 +4,10 @@ fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM=
fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/PaesslerAG/gval v1.2.4 h1:rhX7MpjJlcxYwL2eTTYIOBUyEKZ+A96T9vQySWkVUiU=
github.com/PaesslerAG/gval v1.2.4/go.mod h1:XRFLwvmkTEdYziLdaCeCa5ImcGVrfQbeNUbVR+C6xac=
github.com/PaesslerAG/jsonpath v0.1.0 h1:gADYeifvlqK3R3i2cR5B4DGgxLXIPb3TRTH1mGi0jPI=
github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -59,6 +63,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=

12
gval/LICENSE Normal file
View File

@ -0,0 +1,12 @@
Copyright (c) 2017, Paessler AG <support@paessler.com>
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

150
gval/benchmarks_test.go Normal file
View File

@ -0,0 +1,150 @@
package gval
import (
"context"
"testing"
)
func BenchmarkGval(bench *testing.B) {
benchmarks := []evaluationTest{
{
// Serves as a "water test" to give an idea of the general overhead
name: "const",
expression: "1",
},
{
name: "single parameter",
expression: "requests_made",
parameter: map[string]interface{}{
"requests_made": 99.0,
},
},
{
name: "parameter",
expression: "requests_made > requests_succeeded",
parameter: map[string]interface{}{
"requests_made": 99.0,
"requests_succeeded": 90.0,
},
},
{
// The most common use case, a single variable, modified slightly, compared to a constant.
// This is the "expected" use case.
name: "common",
expression: "(requests_made * requests_succeeded / 100) >= 90",
parameter: map[string]interface{}{
"requests_made": 99.0,
"requests_succeeded": 90.0,
},
},
{
// All major possibilities in one expression.
name: "complex",
expression: `2 > 1 &&
"something" != "nothing" ||
date("2014-01-20") < date("Wed Jul 8 23:07:35 MDT 2015") &&
object["Variable name with spaces"] <= array[0] &&
modifierTest + 1000 / 2 > (80 * 100 % 2)`,
parameter: map[string]interface{}{
"object": map[string]interface{}{"Variable name with spaces": 10.},
"array": []interface{}{0.},
"modifierTest": 7.3,
},
},
{
// no variables, no modifiers
name: "literal",
expression: "(2) > (1)",
},
{
name: "modifier",
expression: "(2) + (2) == (4)",
},
{
// Benchmarks uncompiled parameter regex operators, which are the most expensive of the lot.
// Note that regex compilation times are unpredictable and wily things. The regex engine has a lot of edge cases
// and possible performance pitfalls. This test doesn't aim to be comprehensive against all possible regex scenarios,
// it is primarily concerned with tracking how much longer it takes to compile a regex at evaluation-time than during parse-time.
name: "regex",
expression: "(foo !~ bar) && (foo + bar =~ oba)",
parameter: map[string]interface{}{
"foo": "foo",
"bar": "bar",
"baz": "baz",
"oba": ".*oba.*",
},
},
{
// Benchmarks pre-compilable regex patterns. Meant to serve as a sanity check that constant strings used as regex patterns
// are actually being precompiled.
// Also demonstrates that (generally) compiling a regex at evaluation-time takes an order of magnitude more time than pre-compiling.
name: "constant regex",
expression: `(foo !~ "[bB]az") && (bar =~ "[bB]ar")`,
parameter: map[string]interface{}{
"foo": "foo",
"bar": "bar",
"baz": "baz",
"oba": ".*oba.*",
},
},
{
name: "accessors",
expression: "foo.Int",
parameter: fooFailureParameters,
},
{
name: "accessors method",
expression: "foo.Func()",
parameter: fooFailureParameters,
},
{
name: "accessors method parameter",
expression: `foo.FuncArgStr("bonk")`,
parameter: fooFailureParameters,
},
{
name: "nested accessors",
expression: `foo.Nested.Funk`,
parameter: fooFailureParameters,
},
{
name: "decimal arithmetic",
expression: "(requests_made * requests_succeeded / 100)",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"requests_made": 99.0,
"requests_succeeded": 90.0,
},
},
{
name: "decimal logic",
expression: "(requests_made * requests_succeeded / 100) >= 90",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"requests_made": 99.0,
"requests_succeeded": 90.0,
},
},
}
for _, benchmark := range benchmarks {
eval, err := Full().NewEvaluable(benchmark.expression)
if err != nil {
bench.Fatal(err)
}
_, err = eval(context.Background(), benchmark.parameter)
if err != nil {
bench.Fatal(err)
}
bench.Run(benchmark.name+"_evaluation", func(bench *testing.B) {
for i := 0; i < bench.N; i++ {
eval(context.Background(), benchmark.parameter)
}
})
bench.Run(benchmark.name+"_parsing", func(bench *testing.B) {
for i := 0; i < bench.N; i++ {
Full().NewEvaluable(benchmark.expression)
}
})
}
}

366
gval/evaluable.go Normal file
View File

@ -0,0 +1,366 @@
package gval
import (
"context"
"fmt"
"reflect"
"regexp"
"strconv"
)
// Selector allows for custom variable selection from structs
//
// Return value is again handled with variable() until end of the given path
type Selector interface {
SelectGVal(c context.Context, key string) (interface{}, error)
}
// Evaluable evaluates given parameter
type Evaluable func(c context.Context, parameter interface{}) (interface{}, error)
// EvalInt evaluates given parameter to an int
func (e Evaluable) EvalInt(c context.Context, parameter interface{}) (int, error) {
v, err := e(c, parameter)
if err != nil {
return 0, err
}
f, ok := convertToUint(v)
if !ok {
return 0, fmt.Errorf("expected number but got %v (%T)", v, v)
}
return int(f), nil
}
// EvalUint evaluates given parameter to a float64
func (e Evaluable) EvalUint(c context.Context, parameter interface{}) (uint, error) {
v, err := e(c, parameter)
if err != nil {
return 0, err
}
f, ok := convertToUint(v)
if !ok {
return 0, fmt.Errorf("expected number but got %v (%T)", v, v)
}
return f, nil
}
// EvalBool evaluates given parameter to a bool
func (e Evaluable) EvalBool(c context.Context, parameter interface{}) (bool, error) {
v, err := e(c, parameter)
if err != nil {
return false, err
}
b, ok := convertToBool(v)
if !ok {
return false, fmt.Errorf("expected bool but got %v (%T)", v, v)
}
return b, nil
}
// EvalString evaluates given parameter to a string
func (e Evaluable) EvalString(c context.Context, parameter interface{}) (string, error) {
o, err := e(c, parameter)
if err != nil {
return "", err
}
if s, ok := o.(string); ok {
return s, nil
}
return fmt.Sprintf("%v", o), nil
}
// Const Evaluable represents given constant
func (*Parser) Const(value interface{}) Evaluable {
return constant(value)
}
//go:noinline
func constant(value interface{}) Evaluable {
return func(c context.Context, v interface{}) (interface{}, error) {
return value, nil
}
}
// Var Evaluable represents value at given path.
// It supports with default language VariableSelector:
//
// map[interface{}]interface{},
// map[string]interface{} and
// []interface{} and via reflect
// struct fields,
// struct methods,
// slices and
// map with int or string key.
func (p *Parser) Var(path ...Evaluable) Evaluable {
if p.selector == nil {
return variable(path)
}
return p.selector(path)
}
// Evaluables is a slice of Evaluable.
type Evaluables []Evaluable
// EvalStrings evaluates given parameter to a string slice
func (evs Evaluables) EvalStrings(c context.Context, parameter interface{}) ([]string, error) {
strs := make([]string, len(evs))
for i, p := range evs {
k, err := p.EvalString(c, parameter)
if err != nil {
return nil, err
}
strs[i] = k
}
return strs, nil
}
func variable(path Evaluables) Evaluable {
return func(c context.Context, v interface{}) (interface{}, error) {
v2 := v
for _, p := range path {
k, err := p.EvalString(c, v)
if err != nil {
return nil, err
}
switch o := v2.(type) {
case Selector:
v2, err = o.SelectGVal(c, k)
if err != nil {
return nil, fmt.Errorf("failed to select '%s' on %T: %w", k, o, err)
}
continue
case map[interface{}]interface{}:
v2 = o[k]
continue
case map[string]interface{}:
v2 = o[k]
continue
case []interface{}:
if i, err := strconv.Atoi(k); err == nil && i >= 0 && len(o) > i {
v2 = o[i]
continue
}
default:
var ok bool
v2, ok = reflectSelect(k, o)
if !ok {
return nil, fmt.Errorf("unknown parameter '%s' on %T", k, o)
}
}
}
return v2, nil
}
}
func reflectSelect(key string, value interface{}) (selection interface{}, ok bool) {
vv := reflect.ValueOf(value)
vvElem := resolvePotentialPointer(vv)
switch vvElem.Kind() {
case reflect.Map:
mapKey, ok := reflectConvertTo(vv.Type().Key().Kind(), key)
if !ok {
return nil, false
}
vvElem = vv.MapIndex(reflect.ValueOf(mapKey))
vvElem = resolvePotentialPointer(vvElem)
if vvElem.IsValid() {
return vvElem.Interface(), true
}
// key didn't exist. Check if there is a bound method
method := vv.MethodByName(key)
if method.IsValid() {
return method.Interface(), true
}
case reflect.Slice:
if i, err := strconv.Atoi(key); err == nil && i >= 0 && vv.Len() > i {
vvElem = resolvePotentialPointer(vv.Index(i))
return vvElem.Interface(), true
}
// key not an int. Check if there is a bound method
method := vv.MethodByName(key)
if method.IsValid() {
return method.Interface(), true
}
case reflect.Struct:
field := vvElem.FieldByName(key)
if field.IsValid() {
return field.Interface(), true
}
method := vv.MethodByName(key)
if method.IsValid() {
return method.Interface(), true
}
}
return nil, false
}
func resolvePotentialPointer(value reflect.Value) reflect.Value {
if value.Kind() == reflect.Ptr {
return value.Elem()
}
return value
}
func reflectConvertTo(k reflect.Kind, value string) (interface{}, bool) {
switch k {
case reflect.String:
return value, true
case reflect.Int:
if i, err := strconv.Atoi(value); err == nil {
return i, true
}
}
return nil, false
}
func (*Parser) callFunc(fun function, args ...Evaluable) Evaluable {
return func(c context.Context, v interface{}) (ret interface{}, err error) {
a := make([]interface{}, len(args))
for i, arg := range args {
ai, err := arg(c, v)
if err != nil {
return nil, err
}
a[i] = ai
}
return fun(c, a...)
}
}
func (*Parser) callEvaluable(fullname string, fun Evaluable, args ...Evaluable) Evaluable {
return func(c context.Context, v interface{}) (ret interface{}, err error) {
f, err := fun(c, v)
if err != nil {
return nil, fmt.Errorf("could not call function: %w", err)
}
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("failed to execute function '%s': %s", fullname, r)
ret = nil
}
}()
ff := reflect.ValueOf(f)
if ff.Kind() != reflect.Func {
return nil, fmt.Errorf("could not call '%s' type %T", fullname, f)
}
a := make([]reflect.Value, len(args))
for i := range args {
arg, err := args[i](c, v)
if err != nil {
return nil, err
}
a[i] = reflect.ValueOf(arg)
}
rr := ff.Call(a)
r := make([]interface{}, len(rr))
for i, e := range rr {
r[i] = e.Interface()
}
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
if len(r) > 0 && ff.Type().Out(len(r)-1).Implements(errorInterface) {
if r[len(r)-1] != nil {
err = r[len(r)-1].(error)
}
r = r[0 : len(r)-1]
}
switch len(r) {
case 0:
return err, nil
case 1:
return r[0], err
default:
return r, err
}
}
}
// IsConst returns if the Evaluable is a Parser.Const() value
func (e Evaluable) IsConst() bool {
pc := reflect.ValueOf(constant(nil)).Pointer()
pe := reflect.ValueOf(e).Pointer()
return pc == pe
}
func regEx(a, b Evaluable) (Evaluable, error) {
if !b.IsConst() {
return func(c context.Context, o interface{}) (interface{}, error) {
a, err := a.EvalString(c, o)
if err != nil {
return nil, err
}
b, err := b.EvalString(c, o)
if err != nil {
return nil, err
}
matched, err := regexp.MatchString(b, a)
return matched, err
}, nil
}
s, err := b.EvalString(context.TODO(), nil)
if err != nil {
return nil, err
}
regex, err := regexp.Compile(s)
if err != nil {
return nil, err
}
return func(c context.Context, v interface{}) (interface{}, error) {
s, err := a.EvalString(c, v)
if err != nil {
return nil, err
}
return regex.MatchString(s), nil
}, nil
}
func notRegEx(a, b Evaluable) (Evaluable, error) {
if !b.IsConst() {
return func(c context.Context, o interface{}) (interface{}, error) {
a, err := a.EvalString(c, o)
if err != nil {
return nil, err
}
b, err := b.EvalString(c, o)
if err != nil {
return nil, err
}
matched, err := regexp.MatchString(b, a)
return !matched, err
}, nil
}
s, err := b.EvalString(context.TODO(), nil)
if err != nil {
return nil, err
}
regex, err := regexp.Compile(s)
if err != nil {
return nil, err
}
return func(c context.Context, v interface{}) (interface{}, error) {
s, err := a.EvalString(c, v)
if err != nil {
return nil, err
}
return !regex.MatchString(s), nil
}, nil
}

250
gval/evaluable_test.go Normal file
View File

@ -0,0 +1,250 @@
package gval
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"
)
func TestEvaluable_IsConst(t *testing.T) {
p := Parser{}
tests := []struct {
name string
e Evaluable
want bool
}{
{
"const",
p.Const(80.5),
true,
},
{
"var",
p.Var(),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.e.IsConst(); got != tt.want {
t.Errorf("Evaluable.IsConst() = %v, want %v", got, tt.want)
}
})
}
}
func TestEvaluable_EvalInt(t *testing.T) {
tests := []struct {
name string
e Evaluable
want int
wantErr bool
}{
{
"point",
constant("5.3"),
5,
false,
},
{
"number",
constant(255.),
255,
false,
},
{
"error",
constant("5.3 cm"),
0,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.e.EvalInt(context.Background(), nil)
if (err != nil) != tt.wantErr {
t.Errorf("Evaluable.EvalInt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Evaluable.EvalInt() = %v, want %v", got, tt.want)
}
})
}
}
func TestEvaluable_EvalFloat64(t *testing.T) {
tests := []struct {
name string
e Evaluable
want float64
wantErr bool
}{
{
"point",
constant("5.3"),
5.3,
false,
},
{
"number",
constant(255.),
255,
false,
},
{
"error",
constant("5.3 cm"),
0,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.e.EvalUint(context.Background(), nil)
if (err != nil) != tt.wantErr {
t.Errorf("Evaluable.EvalUint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Evaluable.EvalUint() = %v, want %v", got, tt.want)
}
})
}
}
type testSelector struct {
str string
Map map[string]interface{}
}
func (s testSelector) SelectGVal(ctx context.Context, k string) (interface{}, error) {
if k == "str" {
return s.str, nil
}
if k == "map" {
return s.Map, nil
}
if strings.HasPrefix(k, "deep") {
return s, nil
}
return nil, fmt.Errorf("unknown-key")
}
func TestEvaluable_CustomSelector(t *testing.T) {
var (
lang = Base()
tests = []struct {
name string
expr string
params interface{}
want interface{}
wantErr bool
}{
{
"unknown",
"s.Foo",
map[string]interface{}{"s": &testSelector{}},
nil,
true,
},
{
"field directly",
"s.Str",
map[string]interface{}{"s": &testSelector{str: "test-value"}},
nil,
true,
},
{
"field via selector",
"s.str",
map[string]interface{}{"s": &testSelector{str: "test-value"}},
"test-value",
false,
},
{
"flat",
"str",
&testSelector{str: "test-value"},
"test-value",
false,
},
{
"map field",
"s.map.foo",
map[string]interface{}{"s": &testSelector{Map: map[string]interface{}{"foo": "bar"}}},
"bar",
false,
},
{
"crawl to val",
"deep.deeper.deepest.str",
&testSelector{str: "foo"},
"foo",
false,
},
{
"crawl to struct",
"deep.deeper.deepest",
&testSelector{},
testSelector{},
false,
},
}
booltests = []struct {
name string
expr string
params interface{}
want interface{}
wantErr bool
}{
{
"test method",
"s.IsZero",
map[string]interface{}{"s": time.Now()},
false,
false,
},
}
)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := lang.Evaluate(tt.expr, tt.params)
if (err != nil) != tt.wantErr {
t.Errorf("Evaluable.Evaluate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Evaluable.Evaluate() = %v, want %v", got, tt.want)
}
})
}
for _, tt := range booltests {
t.Run(tt.name, func(t *testing.T) {
got, err := lang.Evaluate(tt.expr, tt.params)
if (err != nil) != tt.wantErr {
t.Errorf("Evaluable.Evaluate() error = %v, wantErr %v", err, tt.wantErr)
return
}
got, ok := convertToBool(got)
if !ok {
t.Errorf("Evaluable.Evaluate() error = nok, wantErr %v", tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Evaluable.Evaluate() = %v, want %v", got, tt.want)
}
})
}
}

450
gval/example_test.go Normal file
View File

@ -0,0 +1,450 @@
package gval_test
import (
"context"
"fmt"
"strings"
"time"
"github.com/PaesslerAG/gval"
"github.com/PaesslerAG/jsonpath"
)
func Example() {
vars := map[string]interface{}{"name": "World"}
value, err := gval.Evaluate(`"Hello " + name + "!"`, vars)
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// Hello World!
}
func ExampleEvaluate() {
value, err := gval.Evaluate("foo > 0", map[string]interface{}{
"foo": -1.,
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// false
}
func ExampleEvaluate_nestedParameter() {
value, err := gval.Evaluate("foo.bar > 0", map[string]interface{}{
"foo": map[string]interface{}{"bar": -1.},
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// false
}
func ExampleEvaluate_array() {
value, err := gval.Evaluate("foo[0]", map[string]interface{}{
"foo": []interface{}{-1.},
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// -1
}
func ExampleEvaluate_complexAccessor() {
value, err := gval.Evaluate(`foo["b" + "a" + "r"]`, map[string]interface{}{
"foo": map[string]interface{}{"bar": -1.},
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// -1
}
func ExampleEvaluate_arithmetic() {
value, err := gval.Evaluate("(requests_made * requests_succeeded / 100) >= 90",
map[string]interface{}{
"requests_made": 100,
"requests_succeeded": 80,
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// false
}
func ExampleEvaluate_string() {
value, err := gval.Evaluate(`http_response_body == "service is ok"`,
map[string]interface{}{
"http_response_body": "service is ok",
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// true
}
func ExampleEvaluate_float64() {
value, err := gval.Evaluate("(mem_used / total_mem) * 100",
map[string]interface{}{
"total_mem": 1024,
"mem_used": 512,
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// 50
}
func ExampleEvaluate_dateComparison() {
value, err := gval.Evaluate("date(`2014-01-02`) > date(`2014-01-01 23:59:59`)",
nil,
// define Date comparison because it is not part expression language gval
gval.InfixOperator(">", func(a, b interface{}) (interface{}, error) {
date1, ok1 := a.(time.Time)
date2, ok2 := b.(time.Time)
if ok1 && ok2 {
return date1.After(date2), nil
}
return nil, fmt.Errorf("unexpected operands types (%T) > (%T)", a, b)
}),
)
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// true
}
func ExampleEvaluable() {
eval, err := gval.Full(gval.Constant("maximum_time", 52)).
NewEvaluable("response_time <= maximum_time")
if err != nil {
fmt.Println(err)
}
for i := 50; i < 55; i++ {
value, err := eval(context.Background(), map[string]interface{}{
"response_time": i,
})
if err != nil {
fmt.Println(err)
}
fmt.Println(value)
}
// Output:
// true
// true
// true
// false
// false
}
func ExampleEvaluate_strlen() {
value, err := gval.Evaluate(`strlen("someReallyLongInputString") <= 16`,
nil,
gval.Function("strlen", func(args ...interface{}) (interface{}, error) {
length := len(args[0].(string))
return (float64)(length), nil
}))
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// false
}
func ExampleEvaluate_encoding() {
value, err := gval.Evaluate(`(7 < "47" == true ? "hello world!\n\u263a" : "good bye\n")`+" + ` more text`",
nil,
gval.Function("strlen", func(args ...interface{}) (interface{}, error) {
length := len(args[0].(string))
return (float64)(length), nil
}))
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// hello world!
// ☺ more text
}
type exampleType struct {
Hello string
}
func (e exampleType) World() string {
return "world"
}
func ExampleEvaluate_accessor() {
value, err := gval.Evaluate(`foo.Hello + foo.World()`,
map[string]interface{}{
"foo": exampleType{Hello: "hello "},
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// hello world
}
func ExampleEvaluate_flatAccessor() {
value, err := gval.Evaluate(`Hello + World()`,
exampleType{Hello: "hello "},
)
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// hello world
}
func ExampleEvaluate_nestedAccessor() {
value, err := gval.Evaluate(`foo.Bar.Hello + foo.Bar.World()`,
map[string]interface{}{
"foo": struct{ Bar exampleType }{
Bar: exampleType{Hello: "hello "},
},
})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// hello world
}
func ExampleVariableSelector() {
value, err := gval.Evaluate(`hello.world`,
"!",
gval.VariableSelector(func(path gval.Evaluables) gval.Evaluable {
return func(c context.Context, v interface{}) (interface{}, error) {
keys, err := path.EvalStrings(c, v)
if err != nil {
return nil, err
}
return fmt.Sprintf("%s%s", strings.Join(keys, " "), v), nil
}
}),
)
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// hello world!
}
func ExampleEvaluable_EvalInt() {
eval, err := gval.Full().NewEvaluable("1 + x")
if err != nil {
fmt.Println(err)
return
}
value, err := eval.EvalInt(context.Background(), map[string]interface{}{"x": 5})
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// 6
}
func ExampleEvaluable_EvalBool() {
eval, err := gval.Full().NewEvaluable("1 == x")
if err != nil {
fmt.Println(err)
return
}
value, err := eval.EvalBool(context.Background(), map[string]interface{}{"x": 1})
if err != nil {
fmt.Println(err)
}
if value {
fmt.Print("yeah")
}
// Output:
// yeah
}
func ExampleEvaluate_jsonpath() {
value, err := gval.Evaluate(`$["response-time"]`,
map[string]interface{}{
"response-time": 100,
},
jsonpath.Language(),
)
if err != nil {
fmt.Println(err)
}
fmt.Print(value)
// Output:
// 100
}
func ExampleLanguage() {
lang := gval.NewLanguage(gval.JSON(), gval.Arithmetic(),
//pipe operator
gval.PostfixOperator("|", func(c context.Context, p *gval.Parser, pre gval.Evaluable) (gval.Evaluable, error) {
post, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
return func(c context.Context, v interface{}) (interface{}, error) {
v, err := pre(c, v)
if err != nil {
return nil, err
}
return post(c, v)
}, nil
}))
eval, err := lang.NewEvaluable(`{"foobar": 50} | foobar + 100`)
if err != nil {
fmt.Println(err)
}
value, err := eval(context.Background(), nil)
if err != nil {
fmt.Println(err)
}
fmt.Println(value)
// Output:
// 150
}
type exampleCustomSelector struct{ hidden string }
var _ gval.Selector = &exampleCustomSelector{}
func (s *exampleCustomSelector) SelectGVal(ctx context.Context, k string) (interface{}, error) {
if k == "hidden" {
return s.hidden, nil
}
return nil, nil
}
func ExampleSelector() {
lang := gval.Base()
value, err := lang.Evaluate(
"myStruct.hidden",
map[string]interface{}{"myStruct": &exampleCustomSelector{hidden: "hello world"}},
)
if err != nil {
fmt.Println(err)
}
fmt.Println(value)
// Output:
// hello world
}
func parseSub(ctx context.Context, p *gval.Parser) (gval.Evaluable, error) {
return p.ParseSublanguage(ctx, subLang)
}
var (
superLang = gval.NewLanguage(
gval.PrefixExtension('$', parseSub),
)
subLang = gval.NewLanguage(
gval.Init(func(ctx context.Context, p *gval.Parser) (gval.Evaluable, error) { return p.Const("hello world"), nil }),
)
)
func ExampleParser_ParseSublanguage() {
value, err := superLang.Evaluate("$", nil)
if err != nil {
fmt.Println(err)
}
fmt.Println(value)
// Output:
// hello world
}

128
gval/functions.go Normal file
View File

@ -0,0 +1,128 @@
package gval
import (
"context"
"fmt"
"reflect"
)
type function func(ctx context.Context, arguments ...interface{}) (interface{}, error)
func toFunc(f interface{}) function {
if f, ok := f.(func(arguments ...interface{}) (interface{}, error)); ok {
return function(func(ctx context.Context, arguments ...interface{}) (interface{}, error) {
var v interface{}
errCh := make(chan error, 1)
go func() {
defer func() {
if recovered := recover(); recovered != nil {
errCh <- fmt.Errorf("%v", recovered)
}
}()
result, err := f(arguments...)
v = result
errCh <- err
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errCh:
close(errCh)
return v, err
}
})
}
if f, ok := f.(func(ctx context.Context, arguments ...interface{}) (interface{}, error)); ok {
return function(f)
}
fun := reflect.ValueOf(f)
t := fun.Type()
return func(ctx context.Context, args ...interface{}) (interface{}, error) {
var v interface{}
errCh := make(chan error, 1)
go func() {
defer func() {
if recovered := recover(); recovered != nil {
errCh <- fmt.Errorf("%v", recovered)
}
}()
in, err := createCallArguments(ctx, t, args)
if err != nil {
errCh <- err
return
}
out := fun.Call(in)
r := make([]interface{}, len(out))
for i, e := range out {
r[i] = e.Interface()
}
err = nil
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
if len(r) > 0 && t.Out(len(r)-1).Implements(errorInterface) {
if r[len(r)-1] != nil {
err = r[len(r)-1].(error)
}
r = r[0 : len(r)-1]
}
switch len(r) {
case 0:
v = nil
case 1:
v = r[0]
default:
v = r
}
errCh <- err
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errCh:
close(errCh)
return v, err
}
}
}
func createCallArguments(ctx context.Context, t reflect.Type, args []interface{}) ([]reflect.Value, error) {
variadic := t.IsVariadic()
numIn := t.NumIn()
// if first argument is a context, use the given execution context
if numIn > 0 {
thisFun := reflect.ValueOf(createCallArguments)
thisT := thisFun.Type()
if t.In(0) == thisT.In(0) {
args = append([]interface{}{ctx}, args...)
}
}
if (!variadic && len(args) != numIn) || (variadic && len(args) < numIn-1) {
return nil, fmt.Errorf("invalid number of parameters")
}
in := make([]reflect.Value, len(args))
var inType reflect.Type
for i, arg := range args {
if !variadic || i < numIn-1 {
inType = t.In(i)
} else if i == numIn-1 {
inType = t.In(numIn - 1).Elem()
}
argVal := reflect.ValueOf(arg)
if arg == nil {
argVal = reflect.Zero(reflect.TypeOf((*interface{})(nil)).Elem())
} else if !argVal.Type().AssignableTo(inType) {
return nil, fmt.Errorf("expected type %s for parameter %d but got %T",
inType.String(), i, arg)
}
in[i] = argVal
}
return in, nil
}

162
gval/functions_test.go Normal file
View File

@ -0,0 +1,162 @@
package gval
import (
"context"
"fmt"
"reflect"
"testing"
"time"
)
func Test_toFunc(t *testing.T) {
myError := fmt.Errorf("my error")
tests := []struct {
name string
function interface{}
arguments []interface{}
want interface{}
wantErr error
wantAnyErr bool
}{
{
name: "empty",
function: func() {},
},
{
name: "one arg",
function: func(a interface{}) {
if a != true {
panic("fail")
}
},
arguments: []interface{}{true},
},
{
name: "three args",
function: func(a, b, c interface{}) {
if a != 1 || b != 2 || c != 3 {
panic("fail")
}
},
arguments: []interface{}{1, 2, 3},
},
{
name: "input types",
function: func(a int, b string, c bool) {
if a != 1 || b != "2" || !c {
panic("fail")
}
},
arguments: []interface{}{1, "2", true},
},
{
name: "wronge input type int",
function: func(a int, b string, c bool) {},
arguments: []interface{}{"1", "2", true},
wantAnyErr: true,
},
{
name: "wronge input type string",
function: func(a int, b string, c bool) {},
arguments: []interface{}{1, 2, true},
wantAnyErr: true,
},
{
name: "wronge input type bool",
function: func(a int, b string, c bool) {},
arguments: []interface{}{1, "2", "true"},
wantAnyErr: true,
},
{
name: "wronge input number",
function: func(a int, b string, c bool) {},
arguments: []interface{}{1, "2"},
wantAnyErr: true,
},
{
name: "one return",
function: func() bool {
return true
},
want: true,
},
{
name: "three returns",
function: func() (bool, string, int) {
return true, "2", 3
},
want: []interface{}{true, "2", 3},
},
{
name: "error",
function: func() error {
return myError
},
wantErr: myError,
},
{
name: "none error",
function: func() error {
return nil
},
},
{
name: "one return with error",
function: func() (bool, error) {
return false, myError
},
want: false,
wantErr: myError,
},
{
name: "three returns with error",
function: func() (bool, string, int, error) {
return false, "", 0, myError
},
want: []interface{}{false, "", 0},
wantErr: myError,
},
{
name: "context not expiring",
function: func(ctx context.Context) error {
return nil
},
},
{
name: "context expires",
function: func(ctx context.Context) error {
time.Sleep(20 * time.Millisecond)
return nil
},
wantErr: context.DeadlineExceeded,
},
{
name: "nil arg",
function: func(a interface{}) bool {
return a == nil
},
arguments: []interface{}{nil},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
got, err := toFunc(tt.function)(ctx, tt.arguments...)
cancel()
if tt.wantAnyErr {
if err != nil {
return
}
t.Fatalf("toFunc()(args...) = error(nil), but wantAnyErr")
}
if err != tt.wantErr {
t.Fatalf("toFunc()(args...) = error(%v), wantErr (%v)", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("toFunc()(args...) = %v, want %v", got, tt.want)
}
})
}
}

334
gval/gval.go Normal file
View File

@ -0,0 +1,334 @@
// Package gval provides a generic expression language.
// All functions, infix and prefix operators can be replaced by composing languages into a new one.
//
// The package contains concrete expression languages for common application in text, arithmetic, decimal arithmetic, propositional logic and so on.
// They can be used as basis for a custom expression language or to evaluate expressions directly.
package gval
import (
"context"
"fmt"
"reflect"
"text/scanner"
"time"
"github.com/shopspring/decimal"
)
// Evaluate given parameter with given expression in gval full language
func Evaluate(expression string, parameter interface{}, opts ...Language) (interface{}, error) {
return EvaluateWithContext(context.Background(), expression, parameter, opts...)
}
// Evaluate given parameter with given expression in gval full language using a context
func EvaluateWithContext(c context.Context, expression string, parameter interface{}, opts ...Language) (interface{}, error) {
l := full
if len(opts) > 0 {
l = NewLanguage(append([]Language{l}, opts...)...)
}
return l.EvaluateWithContext(c, expression, parameter)
}
// Full is the union of Arithmetic, Bitmask, Text, PropositionalLogic, TernaryOperator, and Json
//
// Operator in: a in b is true iff value a is an element of array b
// Operator ??: a ?? b returns a if a is not false or nil, otherwise n
//
// Function Date: Date(a) parses string a. a must match RFC3339, ISO8601, ruby date, or unix date
func Full(extensions ...Language) Language {
if len(extensions) == 0 {
return full
}
return NewLanguage(append([]Language{full}, extensions...)...)
}
// TernaryOperator contains following Operator
//
// ?: a ? b : c returns b if bool a is true, otherwise b
func TernaryOperator() Language {
return ternaryOperator
}
// Arithmetic contains base, plus(+), minus(-), divide(/), power(**), negative(-)
// and numerical order (<=,<,>,>=)
//
// Arithmetic operators expect float64 operands.
// Called with unfitting input, they try to convert the input to float64.
// They can parse strings and convert any type of int or float.
func Arithmetic() Language {
return arithmetic
}
// DecimalArithmetic contains base, plus(+), minus(-), divide(/), power(**), negative(-)
// and numerical order (<=,<,>,>=)
//
// DecimalArithmetic operators expect decimal.Decimal operands (github.com/shopspring/decimal)
// and are used to calculate money/decimal rather than floating point calculations.
// Called with unfitting input, they try to convert the input to decimal.Decimal.
// They can parse strings and convert any type of int or float.
func DecimalArithmetic() Language {
return decimalArithmetic
}
// Bitmask contains base, bitwise and(&), bitwise or(|) and bitwise not(^).
//
// Bitmask operators expect float64 operands.
// Called with unfitting input they try to convert the input to float64.
// They can parse strings and convert any type of int or float.
func Bitmask() Language {
return bitmask
}
// Text contains base, lexical order on strings (<=,<,>,>=),
// regex match (=~) and regex not match (!~)
func Text() Language {
return text
}
// PropositionalLogic contains base, not(!), and (&&), or (||) and Base.
//
// Propositional operator expect bool operands.
// Called with unfitting input they try to convert the input to bool.
// Numbers other than 0 and the strings "TRUE" and "true" are interpreted as true.
// 0 and the strings "FALSE" and "false" are interpreted as false.
func PropositionalLogic() Language {
return propositionalLogic
}
// JSON contains json objects ({string:expression,...})
// and json arrays ([expression, ...])
func JSON() Language {
return ljson
}
// Parentheses contains support for parentheses.
func Parentheses() Language {
return parentheses
}
// Ident contains support for variables and functions.
func Ident() Language {
return ident
}
// Base contains equal (==) and not equal (!=), perentheses and general support for variables, constants and functions
// It contains true, false, (floating point) number, string ("" or “) and char (”) constants
func Base() Language {
return base
}
var full = NewLanguage(arithmetic, bitmask, text, propositionalLogic, ljson,
InfixOperator("in", inArray),
InfixShortCircuit("??", func(a interface{}) (interface{}, bool) {
v := reflect.ValueOf(a)
return a, a != nil && !v.IsZero()
}),
InfixOperator("??", func(a, b interface{}) (interface{}, error) {
if v := reflect.ValueOf(a); a == nil || v.IsZero() {
return b, nil
}
return a, nil
}),
ternaryOperator,
Function("date", func(arguments ...interface{}) (interface{}, error) {
if len(arguments) != 1 {
return nil, fmt.Errorf("date() expects exactly one string argument")
}
s, ok := arguments[0].(string)
if !ok {
return nil, fmt.Errorf("date() expects exactly one string argument")
}
for _, format := range [...]string{
time.ANSIC,
time.UnixDate,
time.RubyDate,
time.Kitchen,
time.RFC3339,
time.RFC3339Nano,
"2006-01-02", // RFC 3339
"2006-01-02 15:04", // RFC 3339 with minutes
"2006-01-02 15:04:05", // RFC 3339 with seconds
"2006-01-02 15:04:05-07:00", // RFC 3339 with seconds and timezone
"2006-01-02T15Z0700", // ISO8601 with hour
"2006-01-02T15:04Z0700", // ISO8601 with minutes
"2006-01-02T15:04:05Z0700", // ISO8601 with seconds
"2006-01-02T15:04:05.999999999Z0700", // ISO8601 with nanoseconds
} {
ret, err := time.ParseInLocation(format, s, time.Local)
if err == nil {
return ret, nil
}
}
return nil, fmt.Errorf("date() could not parse %s", s)
}),
)
var ternaryOperator = PostfixOperator("?", parseIf)
var ljson = NewLanguage(
PrefixExtension('[', parseJSONArray),
PrefixExtension('{', parseJSONObject),
)
var arithmetic = NewLanguage(
InfixNumberOperator("+", func(a, b uint) (interface{}, error) { return a + b, nil }),
InfixNumberOperator("-", func(a, b uint) (interface{}, error) { return a - b, nil }),
InfixNumberOperator("*", func(a, b uint) (interface{}, error) { return a * b, nil }),
InfixNumberOperator("/", func(a, b uint) (interface{}, error) { return a / b, nil }),
InfixNumberOperator("%", func(a, b uint) (interface{}, error) { return a % b, nil }),
InfixNumberOperator(">", func(a, b uint) (interface{}, error) { return a > b, nil }),
InfixNumberOperator(">=", func(a, b uint) (interface{}, error) { return a >= b, nil }),
InfixNumberOperator("<", func(a, b uint) (interface{}, error) { return a < b, nil }),
InfixNumberOperator("<=", func(a, b uint) (interface{}, error) { return a <= b, nil }),
InfixNumberOperator("==", func(a, b uint) (interface{}, error) { return a == b, nil }),
InfixNumberOperator("!=", func(a, b uint) (interface{}, error) { return a != b, nil }),
base,
)
var decimalArithmetic = NewLanguage(
InfixDecimalOperator("+", func(a, b decimal.Decimal) (interface{}, error) { return a.Add(b), nil }),
InfixDecimalOperator("-", func(a, b decimal.Decimal) (interface{}, error) { return a.Sub(b), nil }),
InfixDecimalOperator("*", func(a, b decimal.Decimal) (interface{}, error) { return a.Mul(b), nil }),
InfixDecimalOperator("/", func(a, b decimal.Decimal) (interface{}, error) { return a.Div(b), nil }),
InfixDecimalOperator("%", func(a, b decimal.Decimal) (interface{}, error) { return a.Mod(b), nil }),
InfixDecimalOperator("**", func(a, b decimal.Decimal) (interface{}, error) { return a.Pow(b), nil }),
InfixDecimalOperator(">", func(a, b decimal.Decimal) (interface{}, error) { return a.GreaterThan(b), nil }),
InfixDecimalOperator(">=", func(a, b decimal.Decimal) (interface{}, error) { return a.GreaterThanOrEqual(b), nil }),
InfixDecimalOperator("<", func(a, b decimal.Decimal) (interface{}, error) { return a.LessThan(b), nil }),
InfixDecimalOperator("<=", func(a, b decimal.Decimal) (interface{}, error) { return a.LessThanOrEqual(b), nil }),
InfixDecimalOperator("==", func(a, b decimal.Decimal) (interface{}, error) { return a.Equal(b), nil }),
InfixDecimalOperator("!=", func(a, b decimal.Decimal) (interface{}, error) { return !a.Equal(b), nil }),
base,
//Base is before these overrides so that the Base options are overridden
PrefixExtension(scanner.Int, parseDecimal),
PrefixExtension(scanner.Float, parseDecimal),
PrefixOperator("-", func(c context.Context, v interface{}) (interface{}, error) {
i, ok := convertToUint(v)
if !ok {
return nil, fmt.Errorf("unexpected %v(%T) expected number", v, v)
}
return -i, nil
}),
)
var bitmask = NewLanguage(
InfixNumberOperator("^", func(a, b uint) (interface{}, error) { return uint(int64(a) ^ int64(b)), nil }),
InfixNumberOperator("&", func(a, b uint) (interface{}, error) { return uint(int64(a) & int64(b)), nil }),
InfixNumberOperator("|", func(a, b uint) (interface{}, error) { return uint(int64(a) | int64(b)), nil }),
InfixNumberOperator("<<", func(a, b uint) (interface{}, error) { return uint(int64(a) << uint64(b)), nil }),
InfixNumberOperator(">>", func(a, b uint) (interface{}, error) { return uint(int64(a) >> uint64(b)), nil }),
PrefixOperator("~", func(c context.Context, v interface{}) (interface{}, error) {
i, ok := convertToUint(v)
if !ok {
return nil, fmt.Errorf("unexpected %T expected number", v)
}
return float64(^int64(i)), nil
}),
)
var text = NewLanguage(
InfixTextOperator("+", func(a, b string) (interface{}, error) { return fmt.Sprintf("%v%v", a, b), nil }),
InfixTextOperator("<", func(a, b string) (interface{}, error) { return a < b, nil }),
InfixTextOperator("<=", func(a, b string) (interface{}, error) { return a <= b, nil }),
InfixTextOperator(">", func(a, b string) (interface{}, error) { return a > b, nil }),
InfixTextOperator(">=", func(a, b string) (interface{}, error) { return a >= b, nil }),
InfixEvalOperator("=~", regEx),
InfixEvalOperator("!~", notRegEx),
base,
)
var propositionalLogic = NewLanguage(
PrefixOperator("!", func(c context.Context, v interface{}) (interface{}, error) {
b, ok := convertToBool(v)
if !ok {
return nil, fmt.Errorf("unexpected %T expected bool", v)
}
return !b, nil
}),
InfixShortCircuit("&&", func(a interface{}) (interface{}, bool) { return false, a == false }),
InfixBoolOperator("&&", func(a, b bool) (interface{}, error) { return a && b, nil }),
InfixShortCircuit("||", func(a interface{}) (interface{}, bool) { return true, a == true }),
InfixBoolOperator("||", func(a, b bool) (interface{}, error) { return a || b, nil }),
InfixBoolOperator("==", func(a, b bool) (interface{}, error) { return a == b, nil }),
InfixBoolOperator("!=", func(a, b bool) (interface{}, error) { return a != b, nil }),
base,
)
var parentheses = NewLanguage(
PrefixExtension('(', parseParentheses),
)
var ident = NewLanguage(
PrefixMetaPrefix(scanner.Ident, parseIdent),
)
var base = NewLanguage(
PrefixExtension(scanner.Int, parseNumber),
PrefixExtension(scanner.Float, parseNumber),
PrefixOperator("-", func(c context.Context, v interface{}) (interface{}, error) {
i, ok := convertToUint(v)
if !ok {
return nil, fmt.Errorf("unexpected %v(%T) expected number", v, v)
}
return -i, nil
}),
PrefixExtension(scanner.String, parseString),
PrefixExtension(scanner.Char, parseString),
PrefixExtension(scanner.RawString, parseString),
Constant("true", true),
Constant("false", false),
InfixOperator("==", func(a, b interface{}) (interface{}, error) { return reflect.DeepEqual(a, b), nil }),
InfixOperator("!=", func(a, b interface{}) (interface{}, error) { return !reflect.DeepEqual(a, b), nil }),
parentheses,
Precedence("??", 0),
Precedence("||", 20),
Precedence("&&", 21),
Precedence("==", 40),
Precedence("!=", 40),
Precedence(">", 40),
Precedence(">=", 40),
Precedence("<", 40),
Precedence("<=", 40),
Precedence("=~", 40),
Precedence("!~", 40),
Precedence("in", 40),
Precedence("^", 60),
Precedence("&", 60),
Precedence("|", 60),
Precedence("<<", 90),
Precedence(">>", 90),
Precedence("+", 120),
Precedence("-", 120),
Precedence("*", 150),
Precedence("/", 150),
Precedence("%", 150),
Precedence("**", 200),
ident,
)

View File

@ -0,0 +1,396 @@
package gval
/*
Tests to make sure evaluation fails in the expected ways.
*/
import (
"errors"
"fmt"
"testing"
)
func TestModifierTyping(test *testing.T) {
var (
invalidOperator = "invalid operation"
unknownParameter = "unknown parameter"
invalidRegex = "error parsing regex"
tooFewArguments = "reflect: Call with too few input arguments"
tooManyArguments = "reflect: Call with too many input arguments"
mismatchedParameters = "reflect: Call using"
custom = "test error"
)
evaluationTests := []evaluationTest{
//ModifierTyping
{
name: "PLUS literal number to literal bool",
expression: "1 + true",
want: "1true", // + on string is defined
},
{
name: "PLUS number to bool",
expression: "number + bool",
want: "1true", // + on string is defined
},
{
name: "MINUS number to bool",
expression: "number - bool",
wantErr: invalidOperator,
},
{
name: "MINUS number to bool",
expression: "number - bool",
wantErr: invalidOperator,
},
{
name: "MULTIPLY number to bool",
expression: "number * bool",
wantErr: invalidOperator,
},
{
name: "DIVIDE number to bool",
expression: "number / bool",
wantErr: invalidOperator,
},
{
name: "EXPONENT number to bool",
expression: "number ** bool",
wantErr: invalidOperator,
},
{
name: "MODULUS number to bool",
expression: "number % bool",
wantErr: invalidOperator,
},
{
name: "XOR number to bool",
expression: "number % bool",
wantErr: invalidOperator,
},
{
name: "BITWISE_OR number to bool",
expression: "number | bool",
wantErr: invalidOperator,
},
{
name: "BITWISE_AND number to bool",
expression: "number & bool",
wantErr: invalidOperator,
},
{
name: "BITWISE_XOR number to bool",
expression: "number ^ bool",
wantErr: invalidOperator,
},
{
name: "BITWISE_LSHIFT number to bool",
expression: "number << bool",
wantErr: invalidOperator,
},
{
name: "BITWISE_RSHIFT number to bool",
expression: "number >> bool",
wantErr: invalidOperator,
},
//LogicalOperatorTyping
{
name: "AND number to number",
expression: "number && number",
want: true, // number != 0 is true
},
{
name: "OR number to number",
expression: "number || number",
want: true, // number != 0 is true
},
{
name: "AND string to string",
expression: "string && string",
wantErr: invalidOperator,
},
{
name: "OR string to string",
expression: "string || string",
wantErr: invalidOperator,
},
{
name: "AND number to string",
expression: "number && string",
wantErr: invalidOperator,
},
{
name: "OR number to string",
expression: "number || string",
wantErr: invalidOperator,
},
{
name: "AND bool to string",
expression: "bool && string",
wantErr: invalidOperator,
},
{
name: "OR string to bool",
expression: "string || bool",
wantErr: invalidOperator,
},
//ComparatorTyping
{
name: "GT literal bool to literal bool",
expression: "true > true",
want: false, //lexical order on "true"
},
{
name: "GT bool to bool",
expression: "bool > bool",
want: false, //lexical order on "true"
},
{
name: "GTE bool to bool",
expression: "bool >= bool",
want: true, //lexical order on "true"
},
{
name: "LT bool to bool",
expression: "bool < bool",
want: false, //lexical order on "true"
},
{
name: "LTE bool to bool",
expression: "bool <= bool",
want: true, //lexical order on "true"
},
{
name: "GT number to string",
expression: "number > string",
want: false, //lexical order "1" < "foo"
},
{
name: "GTE number to string",
expression: "number >= string",
want: false, //lexical order "1" < "foo"
},
{
name: "LT number to string",
expression: "number < string",
want: true, //lexical order "1" < "foo"
},
{
name: "REQ number to string",
expression: "number =~ string",
want: false,
},
{
name: "REQ number to bool",
expression: "number =~ bool",
want: false,
},
{
name: "REQ bool to number",
expression: "bool =~ number",
want: false,
},
{
name: "REQ bool to string",
expression: "bool =~ string",
want: false,
},
{
name: "NREQ number to string",
expression: "number !~ string",
want: true,
},
{
name: "NREQ number to bool",
expression: "number !~ bool",
want: true,
},
{
name: "NREQ bool to number",
expression: "bool !~ number",
want: true,
},
{
name: "NREQ bool to string",
expression: "bool !~ string",
want: true,
},
{
name: "IN non-array numeric",
expression: "1 in 2",
wantErr: "expected type []interface{} for in operator but got float64",
},
{
name: "IN non-array string",
expression: `1 in "foo"`,
wantErr: "expected type []interface{} for in operator but got string",
},
{
name: "IN non-array boolean",
expression: "1 in true",
wantErr: "expected type []interface{} for in operator but got bool",
},
//TernaryTyping
{
name: "Ternary with number",
expression: "10 ? true",
want: true, // 10 != nil && 10 != false
},
{
name: "Ternary with string",
expression: `"foo" ? true`,
want: true, // "foo" != nil && "foo" != false
},
//RegexParameterCompilation
{
name: "Regex equality runtime parsing",
expression: `"foo" =~ foo`,
parameter: map[string]interface{}{
"foo": "[foo",
},
wantErr: invalidRegex,
},
{
name: "Regex inequality runtime parsing",
expression: `"foo" !~ foo`,
parameter: map[string]interface{}{
"foo": "[foo",
},
wantErr: invalidRegex,
},
{
name: "Regex equality runtime right side evaluation",
expression: `"foo" =~ error()`,
wantErr: custom,
},
{
name: "Regex inequality runtime right side evaluation",
expression: `"foo" !~ error()`,
wantErr: custom,
},
{
name: "Regex equality runtime left side evaluation",
expression: `error() =~ "."`,
wantErr: custom,
},
{
name: "Regex inequality runtime left side evaluation",
expression: `error() !~ "."`,
wantErr: custom,
},
//FuncExecution
{
name: "Func error bubbling",
expression: "error()",
extension: Function("error", func(arguments ...interface{}) (interface{}, error) {
return nil, errors.New("Huge problems")
}),
wantErr: "Huge problems",
},
//InvalidParameterCalls
{
name: "Missing parameter field reference",
expression: "foo.NotExists",
parameter: fooFailureParameters,
wantErr: unknownParameter,
},
{
name: "Parameter method call on missing function",
expression: "foo.NotExist()",
parameter: fooFailureParameters,
wantErr: unknownParameter,
},
{
name: "Nested missing parameter field reference",
expression: "foo.Nested.NotExists",
parameter: fooFailureParameters,
wantErr: unknownParameter,
},
{
name: "Parameter method call returns error",
expression: "foo.AlwaysFail()",
parameter: fooFailureParameters,
wantErr: "function should always fail",
},
{
name: "Too few arguments to parameter call",
expression: "foo.FuncArgStr()",
parameter: fooFailureParameters,
wantErr: tooFewArguments,
},
{
name: "Too many arguments to parameter call",
expression: `foo.FuncArgStr("foo", "bar", 15)`,
parameter: fooFailureParameters,
wantErr: tooManyArguments,
},
{
name: "Mismatched parameters",
expression: "foo.FuncArgStr(5)",
parameter: fooFailureParameters,
wantErr: mismatchedParameters,
},
{
name: "Negative Array Index",
expression: "foo[-1]",
parameter: map[string]interface{}{
"foo": []int{1, 2, 3},
},
wantErr: unknownParameter,
},
{
name: "Nested slice call index out of bound",
expression: `foo.Nested.Slice[10]`,
parameter: map[string]interface{}{"foo": foo},
wantErr: unknownParameter,
},
{
name: "Nested map call missing key",
expression: `foo.Nested.Map["d"]`,
parameter: map[string]interface{}{"foo": foo},
wantErr: unknownParameter,
},
{
name: "invalid selector",
expression: "hello[world()]",
extension: NewLanguage(Base(), Function("world", func() (int, error) {
return 0, fmt.Errorf("test error")
})),
wantErr: "test error",
},
{
name: "eval `nil > 1` returns true #23",
expression: `nil > 1`,
wantErr: "invalid operation (<nil>) > (float64)",
},
{
name: "map with unknown func",
expression: `foo.MapWithFunc.NotExist()`,
parameter: map[string]interface{}{"foo": foo},
wantErr: unknownParameter,
},
{
name: "map with unknown func",
expression: `foo.SliceWithFunc.NotExist()`,
parameter: map[string]interface{}{"foo": foo},
wantErr: unknownParameter,
},
}
for i := range evaluationTests {
if evaluationTests[i].parameter == nil {
evaluationTests[i].parameter = map[string]interface{}{
"number": 1,
"string": "foo",
"bool": true,
"error": func() (int, error) {
return 0, fmt.Errorf("test error")
},
}
}
}
testEvaluate(evaluationTests, test)
}

View File

@ -0,0 +1,818 @@
package gval
import (
"context"
"fmt"
"testing"
"text/scanner"
)
func TestNoParameter(t *testing.T) {
testEvaluate(
[]evaluationTest{
{
name: "Number",
expression: "100",
want: 100.0,
},
{
name: "Single PLUS",
expression: "51 + 49",
want: 100.0,
},
{
name: "Single MINUS",
expression: "100 - 51",
want: 49.0,
},
{
name: "Single BITWISE AND",
expression: "100 & 50",
want: 32.0,
},
{
name: "Single BITWISE OR",
expression: "100 | 50",
want: 118.0,
},
{
name: "Single BITWISE XOR",
expression: "100 ^ 50",
want: 86.0,
},
{
name: "Single shift left",
expression: "2 << 1",
want: 4.0,
},
{
name: "Single shift right",
expression: "2 >> 1",
want: 1.0,
},
{
name: "Single BITWISE NOT",
expression: "~10",
want: -11.0,
},
{
name: "Single MULTIPLY",
expression: "5 * 20",
want: 100.0,
},
{
name: "Single DIVIDE",
expression: "100 / 20",
want: 5.0,
},
{
name: "Single even MODULUS",
expression: "100 % 2",
want: 0.0,
},
{
name: "Single odd MODULUS",
expression: "101 % 2",
want: 1.0,
},
{
name: "Single EXPONENT",
expression: "10 ** 2",
want: 100.0,
},
{
name: "Compound PLUS",
expression: "20 + 30 + 50",
want: 100.0,
},
{
name: "Compound BITWISE AND",
expression: "20 & 30 & 50",
want: 16.0,
},
{
name: "Mutiple operators",
expression: "20 * 5 - 49",
want: 51.0,
},
{
name: "Parenthesis usage",
expression: "100 - (5 * 10)",
want: 50.0,
},
{
name: "Nested parentheses",
expression: "50 + (5 * (15 - 5))",
want: 100.0,
},
{
name: "Nested parentheses with bitwise",
expression: "100 ^ (23 * (2 | 5))",
want: 197.0,
},
{
name: "Logical OR operation of two clauses",
expression: "(1 == 1) || (true == true)",
want: true,
},
{
name: "Logical AND operation of two clauses",
expression: "(1 == 1) && (true == true)",
want: true,
},
{
name: "Implicit boolean",
expression: "2 > 1",
want: true,
},
{
name: "Equal test minus numbers and no spaces",
expression: "-1==-1",
want: true,
},
{
name: "Compound boolean",
expression: "5 < 10 && 1 < 5",
want: true,
},
{
name: "Evaluated true && false operation (for issue #8)",
expression: "1 > 10 && 11 > 10",
want: false,
},
{
name: "Evaluated true && false operation (for issue #8)",
expression: "true == true && false == true",
want: false,
},
{
name: "Parenthesis boolean",
expression: "10 < 50 && (1 != 2 && 1 > 0)",
want: true,
},
{
name: "Comparison of string constants",
expression: `"foo" == "foo"`,
want: true,
},
{
name: "NEQ comparison of string constants",
expression: `"foo" != "bar"`,
want: true,
},
{
name: "REQ comparison of string constants",
expression: `"foobar" =~ "oba"`,
want: true,
},
{
name: "NREQ comparison of string constants",
expression: `"foo" !~ "bar"`,
want: true,
},
{
name: "Multiplicative/additive order",
expression: "5 + 10 * 2",
want: 25.0,
},
{
name: "Multiple constant multiplications",
expression: "10 * 10 * 10",
want: 1000.0,
},
{
name: "Multiple adds/multiplications",
expression: "10 * 10 * 10 + 1 * 10 * 10",
want: 1100.0,
},
{
name: "Modulus operatorPrecedence",
expression: "1 + 101 % 2 * 5",
want: 6.0,
},
{
name: "Exponent operatorPrecedence",
expression: "1 + 5 ** 3 % 2 * 5",
want: 6.0,
},
{
name: "Bit shift operatorPrecedence",
expression: "50 << 1 & 90",
want: 64.0,
},
{
name: "Bit shift operatorPrecedence",
expression: "90 & 50 << 1",
want: 64.0,
},
{
name: "Bit shift operatorPrecedence amongst non-bitwise",
expression: "90 + 50 << 1 * 5",
want: 4480.0,
},
{
name: "Order of non-commutative same-operatorPrecedence operators (additive)",
expression: "1 - 2 - 4 - 8",
want: -13.0,
},
{
name: "Order of non-commutative same-operatorPrecedence operators (multiplicative)",
expression: "1 * 4 / 2 * 8",
want: 16.0,
},
{
name: "Null coalesce operatorPrecedence",
expression: "true ?? true ? 100 + 200 : 400",
want: 300.0,
},
{
name: "Identical date equivalence",
expression: `"2014-01-02 14:12:22" == "2014-01-02 14:12:22"`,
want: true,
},
{
name: "Positive date GT",
expression: `"2014-01-02 14:12:22" > "2014-01-02 12:12:22"`,
want: true,
},
{
name: "Negative date GT",
expression: `"2014-01-02 14:12:22" > "2014-01-02 16:12:22"`,
want: false,
},
{
name: "Positive date GTE",
expression: `"2014-01-02 14:12:22" >= "2014-01-02 12:12:22"`,
want: true,
},
{
name: "Negative date GTE",
expression: `"2014-01-02 14:12:22" >= "2014-01-02 16:12:22"`,
want: false,
},
{
name: "Positive date LT",
expression: `"2014-01-02 14:12:22" < "2014-01-02 16:12:22"`,
want: true,
},
{
name: "Negative date LT",
expression: `"2014-01-02 14:12:22" < "2014-01-02 11:12:22"`,
want: false,
},
{
name: "Positive date LTE",
expression: `"2014-01-02 09:12:22" <= "2014-01-02 12:12:22"`,
want: true,
},
{
name: "Negative date LTE",
expression: `"2014-01-02 14:12:22" <= "2014-01-02 11:12:22"`,
want: false,
},
{
name: "Sign prefix comparison",
expression: "-1 < 0",
want: true,
},
{
name: "Lexicographic LT",
expression: `"ab" < "abc"`,
want: true,
},
{
name: "Lexicographic LTE",
expression: `"ab" <= "abc"`,
want: true,
},
{
name: "Lexicographic GT",
expression: `"aba" > "abc"`,
want: false,
},
{
name: "Lexicographic GTE",
expression: `"aba" >= "abc"`,
want: false,
},
{
name: "Boolean sign prefix comparison",
expression: "!true == false",
want: true,
},
{
name: "Inversion of clause",
expression: "!(10 < 0)",
want: true,
},
{
name: "Negation after modifier",
expression: "10 * -10",
want: -100.0,
},
{
name: "Ternary with single boolean",
expression: "true ? 10",
want: 10.0,
},
{
name: "Ternary nil with single boolean",
expression: "false ? 10",
want: nil,
},
{
name: "Ternary with comparator boolean",
expression: "10 > 5 ? 35.50",
want: 35.50,
},
{
name: "Ternary nil with comparator boolean",
expression: "1 > 5 ? 35.50",
want: nil,
},
{
name: "Ternary with parentheses",
expression: "(5 * (15 - 5)) > 5 ? 35.50",
want: 35.50,
},
{
name: "Ternary operatorPrecedence",
expression: "true ? 35.50 > 10",
want: true,
},
{
name: "Ternary-else",
expression: "false ? 35.50 : 50",
want: 50.0,
},
{
name: "Ternary-else inside clause",
expression: "(false ? 5 : 35.50) > 10",
want: true,
},
{
name: "Ternary-else (true-case) inside clause",
expression: "(true ? 1 : 5) < 10",
want: true,
},
{
name: "Ternary-else before comparator (negative case)",
expression: "true ? 1 : 5 > 10",
want: 1.0,
},
{
name: "Nested ternaries (#32)",
expression: "(2 == 2) ? 1 : (true ? 2 : 3)",
want: 1.0,
},
{
name: "Nested ternaries, right case (#32)",
expression: "false ? 1 : (true ? 2 : 3)",
want: 2.0,
},
{
name: "Doubly-nested ternaries (#32)",
expression: "true ? (false ? 1 : (false ? 2 : 3)) : (false ? 4 : 5)",
want: 3.0,
},
{
name: "String to string concat",
expression: `"foo" + "bar" == "foobar"`,
want: true,
},
{
name: "String to float64 concat",
expression: `"foo" + 123 == "foo123"`,
want: true,
},
{
name: "Float64 to string concat",
expression: `123 + "bar" == "123bar"`,
want: true,
},
{
name: "String to date concat",
expression: `"foo" + "02/05/1970" == "foobar"`,
want: false,
},
{
name: "String to bool concat",
expression: `"foo" + true == "footrue"`,
want: true,
},
{
name: "Bool to string concat",
expression: `true + "bar" == "truebar"`,
want: true,
},
{
name: "Null coalesce left",
expression: "1 ?? 2",
want: 1.0,
},
{
name: "Array membership literals",
expression: "1 in [1, 2, 3]",
want: true,
},
{
name: "Array membership literal with inversion",
expression: "!(1 in [1, 2, 3])",
want: false,
},
{
name: "Logical operator reordering (#30)",
expression: "(true && true) || (true && false)",
want: true,
},
{
name: "Logical operator reordering without parens (#30)",
expression: "true && true || true && false",
want: true,
},
{
name: "Logical operator reordering with multiple OR (#30)",
expression: "false || true && true || false",
want: true,
},
{
name: "Left-side multiple consecutive (should be reordered) operators",
expression: "(10 * 10 * 10) > 10",
want: true,
},
{
name: "Three-part non-paren logical op reordering (#44)",
expression: "false && true || true",
want: true,
},
{
name: "Three-part non-paren logical op reordering (#44), second one",
expression: "true || false && true",
want: true,
},
{
name: "Logical operator reordering without parens (#45)",
expression: "true && true || false && false",
want: true,
},
{
name: "Single function",
expression: "foo()",
extension: Function("foo", func(arguments ...interface{}) (interface{}, error) {
return true, nil
}),
want: true,
},
{
name: "Func with argument",
expression: "passthrough(1)",
extension: Function("passthrough", func(arguments ...interface{}) (interface{}, error) {
return arguments[0], nil
}),
want: 1.0,
},
{
name: "Func with arguments",
expression: "passthrough(1, 2)",
extension: Function("passthrough", func(arguments ...interface{}) (interface{}, error) {
return arguments[0].(float64) + arguments[1].(float64), nil
}),
want: 3.0,
},
{
name: "Nested function with operatorPrecedence",
expression: "sum(1, sum(2, 3), 2 + 2, true ? 4 : 5)",
extension: Function("sum", func(arguments ...interface{}) (interface{}, error) {
sum := 0.0
for _, v := range arguments {
sum += v.(float64)
}
return sum, nil
}),
want: 14.0,
},
{
name: "Empty function and modifier, compared",
expression: "numeric()-1 > 0",
extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) {
return 2.0, nil
}),
want: true,
},
{
name: "Empty function comparator",
expression: "numeric() > 0",
extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) {
return 2.0, nil
}),
want: true,
},
{
name: "Empty function logical operator",
expression: "success() && !false",
extension: Function("success", func(arguments ...interface{}) (interface{}, error) {
return true, nil
}),
want: true,
},
{
name: "Empty function ternary",
expression: "nope() ? 1 : 2.0",
extension: Function("nope", func(arguments ...interface{}) (interface{}, error) {
return false, nil
}),
want: 2.0,
},
{
name: "Empty function null coalesce",
expression: "null() ?? 2",
extension: Function("null", func(arguments ...interface{}) (interface{}, error) {
return nil, nil
}),
want: 2.0,
},
{
name: "Empty function with prefix",
expression: "-ten()",
extension: Function("ten", func(arguments ...interface{}) (interface{}, error) {
return 10.0, nil
}),
want: -10.0,
},
{
name: "Empty function as part of chain",
expression: "10 - numeric() - 2",
extension: Function("numeric", func(arguments ...interface{}) (interface{}, error) {
return 5.0, nil
}),
want: 3.0,
},
{
name: "Empty function near separator",
expression: "10 in [1, 2, 3, ten(), 8]",
extension: Function("ten", func(arguments ...interface{}) (interface{}, error) {
return 10.0, nil
}),
want: true,
},
{
name: "Enclosed empty function with modifier and comparator (#28)",
expression: "(ten() - 1) > 3",
extension: Function("ten", func(arguments ...interface{}) (interface{}, error) {
return 10.0, nil
}),
want: true,
},
{
name: "Array",
expression: `[(ten() - 1) > 3, (ten() - 1),"hey"]`,
extension: Function("ten", func(arguments ...interface{}) (interface{}, error) {
return 10.0, nil
}),
want: []interface{}{true, 9., "hey"},
},
{
name: "Object",
expression: `{1: (ten() - 1) > 3, 7 + ".X" : (ten() - 1),"hello" : "hey"}`,
extension: Function("ten", func(arguments ...interface{}) (interface{}, error) {
return 10.0, nil
}),
want: map[string]interface{}{"1": true, "7.X": 9., "hello": "hey"},
},
{
name: "Object negativ value",
expression: `{1: -1,"hello" : "hey"}`,
want: map[string]interface{}{"1": -1., "hello": "hey"},
},
{
name: "Empty Array",
expression: `[]`,
want: []interface{}{},
},
{
name: "Empty Object",
expression: `{}`,
want: map[string]interface{}{},
},
{
name: "Variadic",
expression: `sum(1,2,3,4)`,
extension: Function("sum", func(arguments ...float64) (interface{}, error) {
sum := 0.
for _, a := range arguments {
sum += a
}
return sum, nil
}),
want: 10.0,
},
{
name: "Ident Operator",
expression: `1 plus 1`,
extension: InfixNumberOperator("plus", func(a, b float64) (interface{}, error) {
return a + b, nil
}),
want: 2.0,
},
{
name: "Postfix Operator",
expression: ``,
extension: PostfixOperator("§", func(_ context.Context, _ *Parser, eval Evaluable) (Evaluable, error) {
return func(ctx context.Context, parameter interface{}) (interface{}, error) {
i, err := eval.EvalInt(ctx, parameter)
if err != nil {
return nil, err
}
return fmt.Sprintf("§%d", i), nil
}, nil
}),
want: "§4",
},
{
name: "Tabs as non-whitespace",
expression: "4\t5\t6",
extension: NewLanguage(
Init(func(ctx context.Context, p *Parser) (Evaluable, error) {
p.SetWhitespace('\n', '\r', ' ')
return p.ParseExpression(ctx)
}),
InfixNumberOperator("\t", func(a, b float64) (interface{}, error) {
return a * b, nil
}),
),
want: 120.0,
},
{
name: "Handle all other prefixes",
expression: "^foo + $bar + &baz",
extension: DefaultExtension(func(ctx context.Context, p *Parser) (Evaluable, error) {
var mul int
switch p.TokenText() {
case "^":
mul = 1
case "$":
mul = 2
case "&":
mul = 3
}
switch p.Scan() {
case scanner.Ident:
return p.Const(mul * len(p.TokenText())), nil
default:
return nil, p.Expected("length multiplier", scanner.Ident)
}
}),
want: 18.0,
},
{
name: "Embed languages",
expression: "left { 5 + 5 } right",
extension: func() Language {
step := func(ctx context.Context, p *Parser, cur Evaluable) (Evaluable, error) {
next, err := p.ParseExpression(ctx)
if err != nil {
return nil, err
}
return func(ctx context.Context, parameter interface{}) (interface{}, error) {
us, err := cur.EvalString(ctx, parameter)
if err != nil {
return nil, err
}
them, err := next.EvalString(ctx, parameter)
if err != nil {
return nil, err
}
return us + them, nil
}, nil
}
return NewLanguage(
Init(func(ctx context.Context, p *Parser) (Evaluable, error) {
p.SetWhitespace()
p.SetMode(0)
return p.ParseExpression(ctx)
}),
DefaultExtension(func(ctx context.Context, p *Parser) (Evaluable, error) {
return step(ctx, p, p.Const(p.TokenText()))
}),
PrefixExtension(scanner.EOF, func(ctx context.Context, p *Parser) (Evaluable, error) {
return p.Const(""), nil
}),
PrefixExtension('{', func(ctx context.Context, p *Parser) (Evaluable, error) {
eval, err := p.ParseSublanguage(ctx, Full())
if err != nil {
return nil, err
}
switch p.Scan() {
case '}':
default:
return nil, p.Expected("embedded", '}')
}
return step(ctx, p, eval)
}),
)
}(),
want: "left 10 right",
},
{
name: "Late binding",
expression: "5 * [ 10 * { 20 / [ 10 ] } ]",
extension: func() Language {
var inner, outer Language
parseCurly := func(ctx context.Context, p *Parser) (Evaluable, error) {
eval, err := p.ParseSublanguage(ctx, outer)
if err != nil {
return nil, err
}
if p.Scan() != '}' {
return nil, p.Expected("end", '}')
}
return eval, nil
}
parseSquare := func(ctx context.Context, p *Parser) (Evaluable, error) {
eval, err := p.ParseSublanguage(ctx, inner)
if err != nil {
return nil, err
}
if p.Scan() != ']' {
return nil, p.Expected("end", ']')
}
return eval, nil
}
inner = Full(PrefixExtension('{', parseCurly))
outer = Full(PrefixExtension('[', parseSquare))
return outer
}(),
want: 100.0,
},
},
t,
)
}

View File

@ -0,0 +1,719 @@
package gval
import (
"context"
"fmt"
"regexp"
"strings"
"testing"
"time"
"github.com/shopspring/decimal"
)
func TestParameterized(t *testing.T) {
testEvaluate(
[]evaluationTest{
{
name: "Single parameter modified by constant",
expression: "foo + 2",
parameter: map[string]interface{}{
"foo": 2.0,
},
want: 4.0,
},
{
name: "Single parameter modified by variable",
expression: "foo * bar",
parameter: map[string]interface{}{
"foo": 5.0,
"bar": 2.0,
},
want: 10.0,
},
{
name: "Single parameter modified by variable",
expression: `foo["hey"] * bar[1]`,
parameter: map[string]interface{}{
"foo": map[string]interface{}{"hey": 5.0},
"bar": []interface{}{7., 2.0},
},
want: 10.0,
},
{
name: "Multiple multiplications of the same parameter",
expression: "foo * foo * foo",
parameter: map[string]interface{}{
"foo": 10.0,
},
want: 1000.0,
},
{
name: "Multiple additions of the same parameter",
expression: "foo + foo + foo",
parameter: map[string]interface{}{
"foo": 10.0,
},
want: 30.0,
},
{
name: "NoSpaceOperator",
expression: "true&&name",
parameter: map[string]interface{}{
"name": true,
},
want: true,
},
{
name: "Parameter name sensitivity",
expression: "foo + FoO + FOO",
parameter: map[string]interface{}{
"foo": 8.0,
"FoO": 4.0,
"FOO": 2.0,
},
want: 14.0,
},
{
name: "Sign prefix comparison against prefixed variable",
expression: "-1 < -foo",
parameter: map[string]interface{}{"foo": -8.0},
want: true,
},
{
name: "Fixed-point parameter",
expression: "foo > 1",
parameter: map[string]interface{}{"foo": 2},
want: true,
},
{
name: "Modifier after closing clause",
expression: "(2 + 2) + 2 == 6",
want: true,
},
{
name: "Comparator after closing clause",
expression: "(2 + 2) >= 4",
want: true,
},
{
name: "Two-boolean logical operation (for issue #8)",
expression: "(foo == true) || (bar == true)",
parameter: map[string]interface{}{
"foo": true,
"bar": false,
},
want: true,
},
{
name: "Two-variable integer logical operation (for issue #8)",
expression: "foo > 10 && bar > 10",
parameter: map[string]interface{}{
"foo": 1,
"bar": 11,
},
want: false,
},
{
name: "Regex against right-hand parameter",
expression: `"foobar" =~ foo`,
parameter: map[string]interface{}{
"foo": "obar",
},
want: true,
},
{
name: "Not-regex against right-hand parameter",
expression: `"foobar" !~ foo`,
parameter: map[string]interface{}{
"foo": "baz",
},
want: true,
},
{
name: "Regex against two parameter",
expression: `foo =~ bar`,
parameter: map[string]interface{}{
"foo": "foobar",
"bar": "oba",
},
want: true,
},
{
name: "Not-regex against two parameter",
expression: "foo !~ bar",
parameter: map[string]interface{}{
"foo": "foobar",
"bar": "baz",
},
want: true,
},
{
name: "Pre-compiled regex",
expression: "foo =~ bar",
parameter: map[string]interface{}{
"foo": "foobar",
"bar": regexp.MustCompile("[fF][oO]+"),
},
want: true,
},
{
name: "Pre-compiled not-regex",
expression: "foo !~ bar",
parameter: map[string]interface{}{
"foo": "foobar",
"bar": regexp.MustCompile("[fF][oO]+"),
},
want: false,
},
{
name: "Single boolean parameter",
expression: "commission ? 10",
parameter: map[string]interface{}{
"commission": true},
want: 10.0,
},
{
name: "True comparator with a parameter",
expression: `partner == "amazon" ? 10`,
parameter: map[string]interface{}{
"partner": "amazon"},
want: 10.0,
},
{
name: "False comparator with a parameter",
expression: `partner == "amazon" ? 10`,
parameter: map[string]interface{}{
"partner": "ebay"},
want: nil,
},
{
name: "True comparator with multiple parameters",
expression: "theft && period == 24 ? 60",
parameter: map[string]interface{}{
"theft": true,
"period": 24,
},
want: 60.0,
},
{
name: "False comparator with multiple parameters",
expression: "theft && period == 24 ? 60",
parameter: map[string]interface{}{
"theft": false,
"period": 24,
},
want: nil,
},
{
name: "String concat with single string parameter",
expression: `foo + "bar"`,
parameter: map[string]interface{}{
"foo": "baz"},
want: "bazbar",
},
{
name: "String concat with multiple string parameter",
expression: "foo + bar",
parameter: map[string]interface{}{
"foo": "baz",
"bar": "quux",
},
want: "bazquux",
},
{
name: "String concat with float parameter",
expression: "foo + bar",
parameter: map[string]interface{}{
"foo": "baz",
"bar": 123.0,
},
want: "baz123",
},
{
name: "Mixed multiple string concat",
expression: `foo + 123 + "bar" + true`,
parameter: map[string]interface{}{"foo": "baz"},
want: "baz123bartrue",
},
{
name: "Integer width spectrum",
expression: "uint8 + uint16 + uint32 + uint64 + int8 + int16 + int32 + int64",
parameter: map[string]interface{}{
"uint8": uint8(0),
"uint16": uint16(0),
"uint32": uint32(0),
"uint64": uint64(0),
"int8": int8(0),
"int16": int16(0),
"int32": int32(0),
"int64": int64(0),
},
want: 0.0,
},
{
name: "Null coalesce right",
expression: "foo ?? 1.0",
parameter: map[string]interface{}{"foo": nil},
want: 1.0,
},
{
name: "Multiple comparator/logical operators (#30)",
expression: "(foo >= 2887057408 && foo <= 2887122943) || (foo >= 168100864 && foo <= 168118271)",
parameter: map[string]interface{}{"foo": 2887057409},
want: true,
},
{
name: "Multiple comparator/logical operators, opposite order (#30)",
expression: "(foo >= 168100864 && foo <= 168118271) || (foo >= 2887057408 && foo <= 2887122943)",
parameter: map[string]interface{}{"foo": 2887057409},
want: true,
},
{
name: "Multiple comparator/logical operators, small value (#30)",
expression: "(foo >= 2887057408 && foo <= 2887122943) || (foo >= 168100864 && foo <= 168118271)",
parameter: map[string]interface{}{"foo": 168100865},
want: true,
},
{
name: "Multiple comparator/logical operators, small value, opposite order (#30)",
expression: "(foo >= 168100864 && foo <= 168118271) || (foo >= 2887057408 && foo <= 2887122943)",
parameter: map[string]interface{}{"foo": 168100865},
want: true,
},
{
name: "Incomparable array equality comparison",
expression: "arr == arr",
parameter: map[string]interface{}{"arr": []int{0, 0, 0}},
want: true,
},
{
name: "Incomparable array not-equality comparison",
expression: "arr != arr",
parameter: map[string]interface{}{"arr": []int{0, 0, 0}},
want: false,
},
{
name: "Mixed function and parameters",
expression: "sum(1.2, amount) + name",
extension: Function("sum", func(arguments ...interface{}) (interface{}, error) {
sum := 0.0
for _, v := range arguments {
sum += v.(float64)
}
return sum, nil
},
),
parameter: map[string]interface{}{"amount": .8,
"name": "awesome",
},
want: "2awesome",
},
{
name: "Short-circuit OR",
expression: "true || fail()",
extension: Function("fail", func(arguments ...interface{}) (interface{}, error) {
return nil, fmt.Errorf("Did not short-circuit")
}),
want: true,
},
{
name: "Short-circuit AND",
expression: "false && fail()",
extension: Function("fail", func(arguments ...interface{}) (interface{}, error) {
return nil, fmt.Errorf("Did not short-circuit")
}),
want: false,
},
{
name: "Short-circuit ternary",
expression: "true ? 1 : fail()",
extension: Function("fail", func(arguments ...interface{}) (interface{}, error) {
return nil, fmt.Errorf("Did not short-circuit")
}),
want: 1.0,
},
{
name: "Short-circuit coalesce",
expression: `"foo" ?? fail()`,
extension: Function("fail", func(arguments ...interface{}) (interface{}, error) {
return nil, fmt.Errorf("Did not short-circuit")
}),
want: "foo",
},
{
name: "Simple parameter call",
expression: "foo.String",
parameter: map[string]interface{}{"foo": foo},
want: foo.String,
},
{
name: "Simple parameter function call",
expression: "foo.Func()",
parameter: map[string]interface{}{"foo": foo},
want: "funk",
},
{
name: "Simple parameter call from pointer",
expression: "fooptr.String",
parameter: map[string]interface{}{"fooptr": &foo},
want: foo.String,
},
{
name: "Simple parameter function call from pointer",
expression: "fooptr.Func()",
parameter: map[string]interface{}{"fooptr": &foo},
want: "funk",
},
{
name: "Simple parameter call",
expression: `foo.String == "hi"`,
parameter: map[string]interface{}{"foo": foo},
want: false,
},
{
name: "Simple parameter call with modifier",
expression: `foo.String + "hi"`,
parameter: map[string]interface{}{"foo": foo},
want: foo.String + "hi",
},
{
name: "Simple parameter function call, two-arg return",
expression: `foo.Func2()`,
parameter: map[string]interface{}{"foo": foo},
want: "frink",
},
{
name: "Simple parameter function call, one arg",
expression: `foo.FuncArgStr("boop")`,
parameter: map[string]interface{}{"foo": foo},
want: "boop",
},
{
name: "Simple parameter function call, one arg",
expression: `foo.FuncArgStr("boop") + "hi"`,
parameter: map[string]interface{}{"foo": foo},
want: "boophi",
},
{
name: "Nested parameter function call",
expression: `foo.Nested.Dunk("boop")`,
parameter: map[string]interface{}{"foo": foo},
want: "boopdunk",
},
{
name: "Nested parameter call",
expression: "foo.Nested.Funk",
parameter: map[string]interface{}{"foo": foo},
want: "funkalicious",
},
{
name: "Nested map call",
expression: `foo.Nested.Map["a"]`,
parameter: map[string]interface{}{"foo": foo},
want: 1,
},
{
name: "Nested slice call",
expression: `foo.Nested.Slice[1]`,
parameter: map[string]interface{}{"foo": foo},
want: 2,
},
{
name: "Parameter call with + modifier",
expression: "1 + foo.Int",
parameter: map[string]interface{}{"foo": foo},
want: 102.0,
},
{
name: "Parameter string call with + modifier",
expression: `"woop" + (foo.String)`,
parameter: map[string]interface{}{"foo": foo},
want: "woopstring!",
},
{
name: "Parameter call with && operator",
expression: "true && foo.BoolFalse",
parameter: map[string]interface{}{"foo": foo},
want: false,
},
{
name: "Null coalesce nested parameter",
expression: "foo.Nil ?? false",
parameter: map[string]interface{}{"foo": foo},
want: false,
},
{
name: "input functions",
expression: "func1() + func2()",
parameter: map[string]interface{}{
"func1": func() float64 { return 2000 },
"func2": func() float64 { return 2001 },
},
want: 4001.0,
},
{
name: "input functions",
expression: "func1(date1) + func2(date2)",
parameter: map[string]interface{}{
"date1": func() interface{} {
y2k, _ := time.Parse("2006", "2000")
return y2k
}(),
"date2": func() interface{} {
y2k1, _ := time.Parse("2006", "2001")
return y2k1
}(),
},
extension: NewLanguage(
Function("func1", func(arguments ...interface{}) (interface{}, error) {
return float64(arguments[0].(time.Time).Year()), nil
}),
Function("func2", func(arguments ...interface{}) (interface{}, error) {
return float64(arguments[0].(time.Time).Year()), nil
}),
),
want: 4001.0,
},
{
name: "complex64 number as parameter",
expression: "complex64",
parameter: map[string]interface{}{
"complex64": complex64(0),
"complex128": complex128(0),
},
want: complex64(0),
},
{
name: "complex128 number as parameter",
expression: "complex128",
parameter: map[string]interface{}{
"complex64": complex64(0),
"complex128": complex128(0),
},
want: complex128(0),
},
{
name: "coalesce with undefined",
expression: "fooz ?? foo",
parameter: map[string]interface{}{
"foo": "bar",
},
want: "bar",
},
{
name: "map[interface{}]interface{}",
expression: "foo",
parameter: map[interface{}]interface{}{
"foo": "bar",
},
want: "bar",
},
{
name: "method on pointer type",
expression: "foo.PointerFunc()",
parameter: map[string]interface{}{
"foo": &dummyParameter{},
},
want: "point",
},
{
name: "custom selector",
expression: "hello.world",
parameter: "!",
extension: NewLanguage(Base(), VariableSelector(func(path Evaluables) Evaluable {
return func(c context.Context, v interface{}) (interface{}, error) {
keys, err := path.EvalStrings(c, v)
if err != nil {
return nil, err
}
return fmt.Sprintf("%s%s", strings.Join(keys, " "), v), nil
}
})),
want: "hello world!",
},
{
name: "map[int]int",
expression: `a[0] + a[2]`,
parameter: map[string]interface{}{
"a": map[int]int{0: 1, 2: 1},
},
want: 2.,
},
{
name: "map[int]string",
expression: `a[0] * a[2]`,
parameter: map[string]interface{}{
"a": map[int]string{0: "1", 2: "1"},
},
want: 1.,
},
{
name: "coalesce typed nil 0",
expression: `ProjectID ?? 0`,
parameter: struct {
ProjectID *uint
}{},
want: 0.,
},
{
name: "coalesce typed nil 99",
expression: `ProjectID ?? 99`,
parameter: struct {
ProjectID *uint
}{},
want: 99.,
},
{
name: "operator with typed nil 99",
expression: `ProjectID + 99`,
parameter: struct {
ProjectID *uint
}{},
want: "<nil>99",
},
{
name: "operator with typed nil if",
expression: `Flag ? 1 : 2`,
parameter: struct {
Flag *uint
}{},
want: 2.,
},
{
name: "Decimal math doesn't experience rounding error",
expression: "(x * 12.146) - y",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
},
want: decimal.NewFromFloat(156.825),
equalityFunc: decimalEqualityFunc,
},
{
name: "Decimal logical operators fractional difference",
expression: "((x * 12.146) - y) > 156.824999999",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
},
want: true,
},
{
name: "Decimal logical operators whole number difference",
expression: "((x * 12.146) - y) > 156",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
},
want: true,
},
{
name: "Decimal logical operators exact decimal match against GT",
expression: "((x * 12.146) - y) > 156.825",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
},
want: false,
},
{
name: "Decimal logical operators exact equality",
expression: "((x * 12.146) - y) == 156.825",
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
},
want: true,
},
{
name: "Decimal mixes with string logic with force fail",
expression: `(((x * 12.146) - y) == 156.825) && a == "test" && !b && b`,
extension: decimalArithmetic,
parameter: map[string]interface{}{
"x": 12.5,
"y": -5,
"a": "test",
"b": false,
},
want: false,
},
{
name: "Typed map with function call",
expression: `foo.MapWithFunc.Sum("a")`,
parameter: map[string]interface{}{
"foo": foo,
},
want: 3,
},
{
name: "Types slice with function call",
expression: `foo.SliceWithFunc.Sum("a")`,
parameter: map[string]interface{}{
"foo": foo,
},
want: 2,
},
},
t,
)
}

View File

@ -0,0 +1,179 @@
package gval
import (
"regexp/syntax"
"testing"
)
func TestParsingFailure(t *testing.T) {
testEvaluate(
[]evaluationTest{
{
name: "Invalid equality comparator",
expression: "1 = 1",
wantErr: unexpected(`"="`, "operator"),
},
{
name: "Invalid equality comparator",
expression: "1 === 1",
wantErr: unexpected(`"="`, "extension"),
},
{
name: "Too many characters for logical operator",
expression: "true &&& false",
wantErr: unexpected(`"&"`, "extension"),
},
{
name: "Too many characters for logical operator",
expression: "true ||| false",
wantErr: unexpected(`"|"`, "extension"),
},
{
name: "Premature end to expression, via modifier",
expression: "10 > 5 +",
wantErr: unexpected("EOF", "extensions"),
},
{
name: "Premature end to expression, via comparator",
expression: "10 + 5 >",
wantErr: unexpected("EOF", "extensions"),
},
{
name: "Premature end to expression, via logical operator",
expression: "10 > 5 &&",
wantErr: unexpected("EOF", "extensions"),
},
{
name: "Premature end to expression, via ternary operator",
expression: "true ?",
wantErr: unexpected("EOF", "extensions"),
},
{
name: "Hanging REQ",
expression: "`wat` =~",
wantErr: unexpected("EOF", "extensions"),
},
{
name: "Invalid operator change to REQ",
expression: " / =~",
wantErr: unexpected(`"/"`, "extensions"),
},
{
name: "Invalid starting token, comparator",
expression: "> 10",
wantErr: unexpected(`">"`, "extensions"),
},
{
name: "Invalid starting token, modifier",
expression: "+ 5",
wantErr: unexpected(`"+"`, "extensions"),
},
{
name: "Invalid starting token, logical operator",
expression: "&& 5 < 10",
wantErr: unexpected(`"&"`, "extensions"),
},
{
name: "Invalid NUMERIC transition",
expression: "10 10",
wantErr: unexpected(`Int`, "operator"),
},
{
name: "Invalid STRING transition",
expression: "`foo` `foo`",
wantErr: `String while scanning operator`, // can't use func unexpected because the token was changed from String to RawString in go 1.11
},
{
name: "Invalid operator transition",
expression: "10 > < 10",
wantErr: unexpected(`"<"`, "extensions"),
},
{
name: "Starting with unbalanced parens",
expression: " ) ( arg2",
wantErr: unexpected(`")"`, "extensions"),
},
{
name: "Unclosed bracket",
expression: "[foo bar",
wantErr: unexpected(`EOF`, "extensions"),
},
{
name: "Unclosed quote",
expression: "foo == `responseTime",
wantErr: "could not parse string",
},
{
name: "Constant regex pattern fail to compile",
expression: "foo =~ `[abc`",
wantErr: string(syntax.ErrMissingBracket),
},
{
name: "Constant unmatch regex pattern fail to compile",
expression: "foo !~ `[abc`",
wantErr: string(syntax.ErrMissingBracket),
},
{
name: "Unbalanced parentheses",
expression: "10 > (1 + 50",
wantErr: unexpected(`EOF`, "parentheses"),
},
{
name: "Multiple radix",
expression: "127.0.0.1",
wantErr: unexpected(`Float`, "operator"),
},
{
name: "Hanging accessor",
expression: "foo.Bar.",
wantErr: unexpected(`EOF`, "field"),
},
{
name: "Incomplete Hex",
expression: "0x",
wantErr: `strconv.ParseFloat: parsing "0x": invalid syntax`,
},
{
name: "Invalid Hex literal",
expression: "0x > 0",
wantErr: `strconv.ParseFloat: parsing "0x": invalid syntax`,
},
{
name: "Hex float (Unsupported)",
expression: "0x1.1",
wantErr: `strconv.ParseFloat: parsing "0x1.1": invalid syntax`,
},
{
name: "Hex invalid letter",
expression: "0x12g1",
wantErr: `strconv.ParseFloat: parsing "0x12": invalid syntax`,
},
{
name: "Error after camouflage",
expression: "0 + ,",
wantErr: `unexpected "," while scanning extensions`,
},
},
t,
)
}
func unknownOp(op string) string {
return "unknown operator " + op
}
func unexpected(token, unit string) string {
return "unexpected " + token + " while scanning " + unit
}

153
gval/gval_test.go Normal file
View File

@ -0,0 +1,153 @@
package gval
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/shopspring/decimal"
)
type evaluationTest struct {
name string
expression string
extension Language
parameter interface{}
want interface{}
equalityFunc func(x, y interface{}) bool
wantErr string
}
func testEvaluate(tests []evaluationTest, t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Evaluate(tt.expression, tt.parameter, tt.extension)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("Evaluate(%s) expected error but got %v", tt.expression, got)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Evaluate(%s) expected error %s but got error %v", tt.expression, tt.wantErr, err)
}
return
}
if err != nil {
t.Errorf("Evaluate() error = %v", err)
return
}
if ef := tt.equalityFunc; ef != nil {
if !ef(got, tt.want) {
t.Errorf("Evaluate(%s) = %v, want %v", tt.expression, got, tt.want)
}
} else if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Evaluate(%s) = %v, want %v", tt.expression, got, tt.want)
}
})
}
}
// dummyParameter used to test "parameter calls".
type dummyParameter struct {
String string
Int int
BoolFalse bool
Nil interface{}
Nested dummyNestedParameter
MapWithFunc dummyMapWithFunc
SliceWithFunc dummySliceWithFunc
}
func (d dummyParameter) Func() string {
return "funk"
}
func (d dummyParameter) Func2() (string, error) {
return "frink", nil
}
func (d *dummyParameter) PointerFunc() (string, error) {
return "point", nil
}
func (d dummyParameter) FuncErr() (string, error) {
return "", fmt.Errorf("fumps")
}
func (d dummyParameter) FuncArgStr(arg1 string) string {
return arg1
}
func (d dummyParameter) AlwaysFail() (interface{}, error) {
return nil, fmt.Errorf("function should always fail")
}
type dummyNestedParameter struct {
Funk string
Map map[string]int
Slice []int
}
func (d dummyNestedParameter) Dunk(arg1 string) string {
return arg1 + "dunk"
}
var foo = dummyParameter{
String: "string!",
Int: 101,
BoolFalse: false,
Nil: nil,
Nested: dummyNestedParameter{
Funk: "funkalicious",
Map: map[string]int{"a": 1, "b": 2, "c": 3},
Slice: []int{1, 2, 3},
},
MapWithFunc: dummyMapWithFunc{"a": {1, 2}, "b": {3, 4}},
SliceWithFunc: dummySliceWithFunc{"a", "b", "c", "a"},
}
var fooFailureParameters = map[string]interface{}{
"foo": foo,
"fooptr": &foo,
}
var decimalEqualityFunc = func(x, y interface{}) bool {
v1, ok1 := x.(decimal.Decimal)
v2, ok2 := y.(decimal.Decimal)
if !ok1 || !ok2 {
return false
}
return v1.Equal(v2)
}
type dummyMapWithFunc map[string][]int
func (m dummyMapWithFunc) Sum(key string) int {
values, ok := m[key]
if !ok {
return -1
}
sum := 0
for _, v := range values {
sum += v
}
return sum
}
type dummySliceWithFunc []string
func (m dummySliceWithFunc) Sum(key string) int {
sum := 0
for _, v := range m {
if v == key {
sum += 1
}
}
return sum
}

281
gval/language.go Normal file
View File

@ -0,0 +1,281 @@
package gval
import (
"context"
"fmt"
"text/scanner"
"unicode"
"github.com/shopspring/decimal"
)
// Language is an expression language
type Language struct {
prefixes map[interface{}]extension
operators map[string]operator
operatorSymbols map[rune]struct{}
init extension
def extension
selector func(Evaluables) Evaluable
}
// NewLanguage returns the union of given Languages as new Language.
func NewLanguage(bases ...Language) Language {
l := newLanguage()
for _, base := range bases {
for i, e := range base.prefixes {
l.prefixes[i] = e
}
for i, e := range base.operators {
l.operators[i] = e.merge(l.operators[i])
l.operators[i].initiate(i)
}
for i := range base.operatorSymbols {
l.operatorSymbols[i] = struct{}{}
}
if base.init != nil {
l.init = base.init
}
if base.def != nil {
l.def = base.def
}
if base.selector != nil {
l.selector = base.selector
}
}
return l
}
func newLanguage() Language {
return Language{
prefixes: map[interface{}]extension{},
operators: map[string]operator{},
operatorSymbols: map[rune]struct{}{},
}
}
// NewEvaluable returns an Evaluable for given expression in the specified language
func (l Language) NewEvaluable(expression string) (Evaluable, error) {
return l.NewEvaluableWithContext(context.Background(), expression)
}
// NewEvaluableWithContext returns an Evaluable for given expression in the specified language using context
func (l Language) NewEvaluableWithContext(c context.Context, expression string) (Evaluable, error) {
p := newParser(expression, l)
eval, err := p.parse(c)
if err == nil && p.isCamouflaged() && p.lastScan != scanner.EOF {
err = p.camouflage
}
if err != nil {
pos := p.scanner.Pos()
return nil, fmt.Errorf("parsing error: %s - %d:%d %w", p.scanner.Position, pos.Line, pos.Column, err)
}
return eval, nil
}
// Evaluate given parameter with given expression
func (l Language) Evaluate(expression string, parameter interface{}) (interface{}, error) {
return l.EvaluateWithContext(context.Background(), expression, parameter)
}
// Evaluate given parameter with given expression using context
func (l Language) EvaluateWithContext(c context.Context, expression string, parameter interface{}) (interface{}, error) {
eval, err := l.NewEvaluableWithContext(c, expression)
if err != nil {
return nil, err
}
v, err := eval(c, parameter)
if err != nil {
return nil, fmt.Errorf("can not evaluate %s: %w", expression, err)
}
return v, nil
}
// Function returns a Language with given function.
// Function has no conversion for input types.
//
// If the function returns an error it must be the last return parameter.
//
// If the function has (without the error) more then one return parameter,
// it returns them as []interface{}.
func Function(name string, function interface{}) Language {
l := newLanguage()
l.prefixes[name] = func(c context.Context, p *Parser) (eval Evaluable, err error) {
args := []Evaluable{}
scan := p.Scan()
switch scan {
case '(':
args, err = p.parseArguments(c)
if err != nil {
return nil, err
}
default:
p.Camouflage("function call", '(')
}
return p.callFunc(toFunc(function), args...), nil
}
return l
}
// Constant returns a Language with given constant
func Constant(name string, value interface{}) Language {
l := newLanguage()
l.prefixes[l.makePrefixKey(name)] = func(c context.Context, p *Parser) (eval Evaluable, err error) {
return p.Const(value), nil
}
return l
}
// PrefixExtension extends a Language
func PrefixExtension(r rune, ext func(context.Context, *Parser) (Evaluable, error)) Language {
l := newLanguage()
l.prefixes[r] = ext
return l
}
// Init is a language that does no parsing, but invokes the given function when
// parsing starts. It is incumbent upon the function to call ParseExpression to
// continue parsing.
//
// This function can be used to customize the parser settings, such as
// whitespace or ident behavior.
func Init(ext func(context.Context, *Parser) (Evaluable, error)) Language {
l := newLanguage()
l.init = ext
return l
}
// DefaultExtension is a language that runs the given function if no other
// prefix matches.
func DefaultExtension(ext func(context.Context, *Parser) (Evaluable, error)) Language {
l := newLanguage()
l.def = ext
return l
}
// PrefixMetaPrefix chooses a Prefix to be executed
func PrefixMetaPrefix(r rune, ext func(context.Context, *Parser) (call string, alternative func() (Evaluable, error), err error)) Language {
l := newLanguage()
l.prefixes[r] = func(c context.Context, p *Parser) (Evaluable, error) {
call, alternative, err := ext(c, p)
if err != nil {
return nil, err
}
if prefix, ok := p.prefixes[l.makePrefixKey(call)]; ok {
return prefix(c, p)
}
return alternative()
}
return l
}
// PrefixOperator returns a Language with given prefix
func PrefixOperator(name string, e Evaluable) Language {
l := newLanguage()
l.prefixes[l.makePrefixKey(name)] = func(c context.Context, p *Parser) (Evaluable, error) {
eval, err := p.ParseNextExpression(c)
if err != nil {
return nil, err
}
prefix := func(c context.Context, v interface{}) (interface{}, error) {
a, err := eval(c, v)
if err != nil {
return nil, err
}
return e(c, a)
}
if eval.IsConst() {
v, err := prefix(c, nil)
if err != nil {
return nil, err
}
prefix = p.Const(v)
}
return prefix, nil
}
return l
}
// PostfixOperator extends a Language.
func PostfixOperator(name string, ext func(context.Context, *Parser, Evaluable) (Evaluable, error)) Language {
l := newLanguage()
l.operators[l.makeInfixKey(name)] = postfix{
f: func(c context.Context, p *Parser, eval Evaluable, pre operatorPrecedence) (Evaluable, error) {
return ext(c, p, eval)
},
}
return l
}
// InfixOperator for two arbitrary values.
func InfixOperator(name string, f func(a, b interface{}) (interface{}, error)) Language {
return newLanguageOperator(name, &infix{arbitrary: f})
}
// InfixShortCircuit operator is called after the left operand is evaluated.
func InfixShortCircuit(name string, f func(a interface{}) (interface{}, bool)) Language {
return newLanguageOperator(name, &infix{shortCircuit: f})
}
// InfixTextOperator for two text values.
func InfixTextOperator(name string, f func(a, b string) (interface{}, error)) Language {
return newLanguageOperator(name, &infix{text: f})
}
// InfixNumberOperator for two number values.
func InfixNumberOperator(name string, f func(a, b uint) (interface{}, error)) Language {
return newLanguageOperator(name, &infix{number: f})
}
// InfixDecimalOperator for two decimal values.
func InfixDecimalOperator(name string, f func(a, b decimal.Decimal) (interface{}, error)) Language {
return newLanguageOperator(name, &infix{decimal: f})
}
// InfixBoolOperator for two bool values.
func InfixBoolOperator(name string, f func(a, b bool) (interface{}, error)) Language {
return newLanguageOperator(name, &infix{boolean: f})
}
// Precedence of operator. The Operator with higher operatorPrecedence is evaluated first.
func Precedence(name string, operatorPrecendence uint8) Language {
return newLanguageOperator(name, operatorPrecedence(operatorPrecendence))
}
// InfixEvalOperator operates on the raw operands.
// Therefore it cannot be combined with operators for other operand types.
func InfixEvalOperator(name string, f func(a, b Evaluable) (Evaluable, error)) Language {
return newLanguageOperator(name, directInfix{infixBuilder: f})
}
func newLanguageOperator(name string, op operator) Language {
op.initiate(name)
l := newLanguage()
l.operators[l.makeInfixKey(name)] = op
return l
}
func (l *Language) makePrefixKey(key string) interface{} {
runes := []rune(key)
if len(runes) == 1 && !unicode.IsLetter(runes[0]) {
return runes[0]
}
return key
}
func (l *Language) makeInfixKey(key string) string {
for _, r := range key {
l.operatorSymbols[r] = struct{}{}
}
return key
}
// VariableSelector returns a Language which uses given variable selector.
// It must be combined with a Language that uses the vatiable selector. E.g. gval.Base().
func VariableSelector(selector func(path Evaluables) Evaluable) Language {
l := newLanguage()
l.selector = selector
return l
}

403
gval/operator.go Normal file
View File

@ -0,0 +1,403 @@
package gval
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/shopspring/decimal"
)
type stage struct {
Evaluable
infixBuilder
operatorPrecedence
}
type stageStack []stage //operatorPrecedence in stacktStage is continuously, monotone ascending
func (s *stageStack) push(b stage) error {
for len(*s) > 0 && s.peek().operatorPrecedence >= b.operatorPrecedence {
a := s.pop()
eval, err := a.infixBuilder(a.Evaluable, b.Evaluable)
if err != nil {
return err
}
if a.IsConst() && b.IsConst() {
v, err := eval(nil, nil)
if err != nil {
return err
}
b.Evaluable = constant(v)
continue
}
b.Evaluable = eval
}
*s = append(*s, b)
return nil
}
func (s *stageStack) peek() stage {
return (*s)[len(*s)-1]
}
func (s *stageStack) pop() stage {
a := s.peek()
(*s) = (*s)[:len(*s)-1]
return a
}
type infixBuilder func(a, b Evaluable) (Evaluable, error)
func (l Language) isSymbolOperation(r rune) bool {
_, in := l.operatorSymbols[r]
return in
}
func (l Language) isOperatorPrefix(op string) bool {
for k := range l.operators {
if strings.HasPrefix(k, op) {
return true
}
}
return false
}
func (op *infix) initiate(name string) {
f := func(a, b interface{}) (interface{}, error) {
return nil, fmt.Errorf("invalid operation (%T) %s (%T)", a, name, b)
}
if op.arbitrary != nil {
f = op.arbitrary
}
for _, typeConvertion := range []bool{true, false} {
if op.text != nil && (!typeConvertion || op.arbitrary == nil) {
f = getStringOpFunc(op.text, f, typeConvertion)
}
if op.boolean != nil {
f = getBoolOpFunc(op.boolean, f, typeConvertion)
}
if op.number != nil {
f = getUintOpFunc(op.number, f, typeConvertion)
}
if op.decimal != nil {
f = getDecimalOpFunc(op.decimal, f, typeConvertion)
}
}
if op.shortCircuit == nil {
op.builder = func(a, b Evaluable) (Evaluable, error) {
return func(c context.Context, x interface{}) (interface{}, error) {
a, err := a(c, x)
if err != nil {
return nil, err
}
b, err := b(c, x)
if err != nil {
return nil, err
}
return f(a, b)
}, nil
}
return
}
shortF := op.shortCircuit
op.builder = func(a, b Evaluable) (Evaluable, error) {
return func(c context.Context, x interface{}) (interface{}, error) {
a, err := a(c, x)
if err != nil {
return nil, err
}
if r, ok := shortF(a); ok {
return r, nil
}
b, err := b(c, x)
if err != nil {
return nil, err
}
return f(a, b)
}, nil
}
}
type opFunc func(a, b interface{}) (interface{}, error)
func getStringOpFunc(s func(a, b string) (interface{}, error), f opFunc, typeConversion bool) opFunc {
if typeConversion {
return func(a, b interface{}) (interface{}, error) {
if a != nil && b != nil {
return s(fmt.Sprintf("%v", a), fmt.Sprintf("%v", b))
}
return f(a, b)
}
}
return func(a, b interface{}) (interface{}, error) {
s1, k := a.(string)
s2, l := b.(string)
if k && l {
return s(s1, s2)
}
return f(a, b)
}
}
func convertToBool(o interface{}) (bool, bool) {
if b, ok := o.(bool); ok {
return b, true
}
v := reflect.ValueOf(o)
if v.Kind() == reflect.Func {
if vt := v.Type(); vt.NumIn() == 0 && vt.NumOut() == 1 {
retType := vt.Out(0)
if retType.Kind() == reflect.Bool {
funcResults := v.Call([]reflect.Value{})
v = funcResults[0]
o = v.Interface()
}
}
}
for o != nil && v.Kind() == reflect.Ptr {
v = v.Elem()
if !v.IsValid() {
return false, false
}
o = v.Interface()
}
if o == false || o == nil || o == "false" || o == "FALSE" {
return false, true
}
if o == true || o == "true" || o == "TRUE" {
return true, true
}
if f, ok := convertToUint(o); ok {
return f != 0., true
}
return false, false
}
func getBoolOpFunc(o func(a, b bool) (interface{}, error), f opFunc, typeConversion bool) opFunc {
if typeConversion {
return func(a, b interface{}) (interface{}, error) {
x, k := convertToBool(a)
y, l := convertToBool(b)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
return func(a, b interface{}) (interface{}, error) {
x, k := a.(bool)
y, l := b.(bool)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
func convertToUint(o interface{}) (uint, bool) {
if i, ok := o.(uint); ok {
return i, true
}
v := reflect.ValueOf(o)
for o != nil && v.Kind() == reflect.Ptr {
v = v.Elem()
if !v.IsValid() {
return 0, false
}
o = v.Interface()
}
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return uint(v.Int()), true
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uint(v.Uint()), true
}
if s, ok := o.(string); ok {
u, err := strconv.ParseUint(s, 0, 32)
if err == nil {
return uint(u), true
}
}
return 0, false
}
func getUintOpFunc(o func(a, b uint) (interface{}, error), f opFunc, typeConversion bool) opFunc {
if typeConversion {
return func(a, b interface{}) (interface{}, error) {
x, k := convertToUint(a)
y, l := convertToUint(b)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
return func(a, b interface{}) (interface{}, error) {
x, k := a.(uint)
y, l := b.(uint)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
func convertToDecimal(o interface{}) (decimal.Decimal, bool) {
if i, ok := o.(decimal.Decimal); ok {
return i, true
}
if i, ok := o.(float64); ok {
return decimal.NewFromFloat(i), true
}
v := reflect.ValueOf(o)
for o != nil && v.Kind() == reflect.Ptr {
v = v.Elem()
if !v.IsValid() {
return decimal.Zero, false
}
o = v.Interface()
}
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return decimal.NewFromInt(v.Int()), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return decimal.NewFromFloat(float64(v.Uint())), true
case reflect.Float32, reflect.Float64:
return decimal.NewFromFloat(v.Float()), true
}
if s, ok := o.(string); ok {
f, err := strconv.ParseFloat(s, 64)
if err == nil {
return decimal.NewFromFloat(f), true
}
}
return decimal.Zero, false
}
func getDecimalOpFunc(o func(a, b decimal.Decimal) (interface{}, error), f opFunc, typeConversion bool) opFunc {
if typeConversion {
return func(a, b interface{}) (interface{}, error) {
x, k := convertToDecimal(a)
y, l := convertToDecimal(b)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
return func(a, b interface{}) (interface{}, error) {
x, k := a.(decimal.Decimal)
y, l := b.(decimal.Decimal)
if k && l {
return o(x, y)
}
return f(a, b)
}
}
type operator interface {
merge(operator) operator
precedence() operatorPrecedence
initiate(name string)
}
type operatorPrecedence uint8
func (pre operatorPrecedence) merge(op operator) operator {
if op, ok := op.(operatorPrecedence); ok {
if op > pre {
return op
}
return pre
}
if op == nil {
return pre
}
return op.merge(pre)
}
func (pre operatorPrecedence) precedence() operatorPrecedence {
return pre
}
func (pre operatorPrecedence) initiate(name string) {}
type infix struct {
operatorPrecedence
number func(a, b uint) (interface{}, error)
decimal func(a, b decimal.Decimal) (interface{}, error)
boolean func(a, b bool) (interface{}, error)
text func(a, b string) (interface{}, error)
arbitrary func(a, b interface{}) (interface{}, error)
shortCircuit func(a interface{}) (interface{}, bool)
builder infixBuilder
}
func (op infix) merge(op2 operator) operator {
switch op2 := op2.(type) {
case *infix:
if op.number == nil {
op.number = op2.number
}
if op.decimal == nil {
op.decimal = op2.decimal
}
if op.boolean == nil {
op.boolean = op2.boolean
}
if op.text == nil {
op.text = op2.text
}
if op.arbitrary == nil {
op.arbitrary = op2.arbitrary
}
if op.shortCircuit == nil {
op.shortCircuit = op2.shortCircuit
}
}
if op2 != nil && op2.precedence() > op.operatorPrecedence {
op.operatorPrecedence = op2.precedence()
}
return &op
}
type directInfix struct {
operatorPrecedence
infixBuilder
}
func (op directInfix) merge(op2 operator) operator {
switch op2 := op2.(type) {
case operatorPrecedence:
op.operatorPrecedence = op2
}
if op2 != nil && op2.precedence() > op.operatorPrecedence {
op.operatorPrecedence = op2.precedence()
}
return op
}
type extension func(context.Context, *Parser) (Evaluable, error)
type postfix struct {
operatorPrecedence
f func(context.Context, *Parser, Evaluable, operatorPrecedence) (Evaluable, error)
}
func (op postfix) merge(op2 operator) operator {
switch op2 := op2.(type) {
case postfix:
if op2.f != nil {
op.f = op2.f
}
}
if op2 != nil && op2.precedence() > op.operatorPrecedence {
op.operatorPrecedence = op2.precedence()
}
return op
}

166
gval/operator_test.go Normal file
View File

@ -0,0 +1,166 @@
package gval
import (
"context"
"fmt"
"reflect"
"testing"
)
func Test_Infix(t *testing.T) {
type subTest struct {
name string
a interface{}
b interface{}
wantRet interface{}
}
tests := []struct {
name string
infix
subTests []subTest
}{
{
"number operator",
infix{
number: func(a, b float64) (interface{}, error) { return a * b, nil },
},
[]subTest{
{"float64 arguments", 7., 3., 21.},
{"int arguments", 7, 3, 21.},
{"string arguments", "7", "3.", 21.},
},
},
{
"number and string operator",
infix{
number: func(a, b float64) (interface{}, error) { return a + b, nil },
text: func(a, b string) (interface{}, error) { return fmt.Sprintf("%v%v", a, b), nil },
},
[]subTest{
{"float64 arguments", 7., 3., 10.},
{"int arguments", 7, 3, 10.},
{"number string arguments", "7", "3.", "73."},
{"string arguments", "hello ", "world", "hello world"},
},
},
{
"bool operator",
infix{
shortCircuit: func(a interface{}) (interface{}, bool) { return false, a == false },
boolean: func(a, b bool) (interface{}, error) { return a && b, nil },
},
[]subTest{
{"bool arguments", false, true, false},
{"number arguments", 0, true, false},
{"lower string arguments", "false", "true", false},
{"upper string arguments", "TRUE", "FALSE", false},
{"shortCircuit", false, "not a boolean", false},
},
},
{
"bool, number, text and interface operator",
infix{
number: func(a, b float64) (interface{}, error) { return a == b, nil },
boolean: func(a, b bool) (interface{}, error) { return a == b, nil },
text: func(a, b string) (interface{}, error) { return a == b, nil },
arbitrary: func(a, b interface{}) (interface{}, error) { return a == b, nil },
},
[]subTest{
{"number string and int arguments", "7", 7, true},
{"bool string and bool arguments", "true", true, true},
{"string arguments", "hello", "hello", true},
{"upper string arguments", "TRUE", "FALSE", false},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.infix.initiate("<" + tt.name + ">")
builder := tt.infix.builder
for _, tt := range tt.subTests {
t.Run(tt.name, func(t *testing.T) {
eval, err := builder(constant(tt.a), constant(tt.b))
if err != nil {
t.Fatal(err)
}
got, err := eval(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, tt.wantRet) {
t.Fatalf("binaryOperator() eval() = %v, want %v", got, tt.wantRet)
}
})
}
})
}
}
func Test_stageStack_push(t *testing.T) {
p := (*Parser)(nil)
tests := []struct {
name string
pres []operatorPrecedence
expect string
}{
{
"flat",
[]operatorPrecedence{1, 1, 1, 1},
"((((AB)C)D)E)",
},
{
"asc",
[]operatorPrecedence{1, 2, 3, 4},
"(A(B(C(DE))))",
},
{
"desc",
[]operatorPrecedence{4, 3, 2, 1},
"((((AB)C)D)E)",
},
{
"mixed",
[]operatorPrecedence{1, 2, 1, 1},
"(((A(BC))D)E)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
X := int('A')
op := func(a, b Evaluable) (Evaluable, error) {
return func(c context.Context, o interface{}) (interface{}, error) {
aa, _ := a.EvalString(c, nil)
bb, _ := b.EvalString(c, nil)
s := "(" + aa + bb + ")"
return s, nil
}, nil
}
stack := stageStack{}
for _, pre := range tt.pres {
if err := stack.push(stage{p.Const(string(rune(X))), op, pre}); err != nil {
t.Fatal(err)
}
X++
}
if err := stack.push(stage{p.Const(string(rune(X))), nil, 0}); err != nil {
t.Fatal(err)
}
if len(stack) != 1 {
t.Fatalf("stack must hold exactly one element")
}
got, _ := stack[0].EvalString(context.Background(), nil)
if got != tt.expect {
t.Fatalf("got %s but expected %s", got, tt.expect)
}
})
}
}

347
gval/parse.go Normal file
View File

@ -0,0 +1,347 @@
package gval
import (
"context"
"fmt"
"reflect"
"strconv"
"text/scanner"
"github.com/shopspring/decimal"
)
// ParseExpression scans an expression into an Evaluable.
func (p *Parser) ParseExpression(c context.Context) (eval Evaluable, err error) {
stack := stageStack{}
for {
eval, err = p.ParseNextExpression(c)
if err != nil {
return nil, err
}
if stage, err := p.parseOperator(c, &stack, eval); err != nil {
return nil, err
} else if err = stack.push(stage); err != nil {
return nil, err
}
if stack.peek().infixBuilder == nil {
return stack.pop().Evaluable, nil
}
}
}
// ParseNextExpression scans the expression ignoring following operators
func (p *Parser) ParseNextExpression(c context.Context) (eval Evaluable, err error) {
scan := p.Scan()
ex, ok := p.prefixes[scan]
if !ok {
if scan != scanner.EOF && p.def != nil {
return p.def(c, p)
}
return nil, p.Expected("extensions")
}
return ex(c, p)
}
// ParseSublanguage sets the next language for this parser to parse and calls
// its initialization function, usually ParseExpression.
func (p *Parser) ParseSublanguage(c context.Context, l Language) (Evaluable, error) {
if p.isCamouflaged() {
panic("can not ParseSublanguage() on camouflaged Parser")
}
curLang := p.Language
curWhitespace := p.scanner.Whitespace
curMode := p.scanner.Mode
curIsIdentRune := p.scanner.IsIdentRune
p.Language = l
p.resetScannerProperties()
defer func() {
p.Language = curLang
p.scanner.Whitespace = curWhitespace
p.scanner.Mode = curMode
p.scanner.IsIdentRune = curIsIdentRune
}()
return p.parse(c)
}
func (p *Parser) parse(c context.Context) (Evaluable, error) {
if p.init != nil {
return p.init(c, p)
}
return p.ParseExpression(c)
}
func parseString(c context.Context, p *Parser) (Evaluable, error) {
s, err := strconv.Unquote(p.TokenText())
if err != nil {
return nil, fmt.Errorf("could not parse string: %w", err)
}
return p.Const(s), nil
}
func parseNumber(c context.Context, p *Parser) (Evaluable, error) {
n, err := strconv.ParseUint(p.TokenText(), 0, 32)
if err != nil {
return nil, err
}
return p.Const(n), nil
}
func parseDecimal(c context.Context, p *Parser) (Evaluable, error) {
n, err := strconv.ParseFloat(p.TokenText(), 64)
if err != nil {
return nil, err
}
return p.Const(decimal.NewFromFloat(n)), nil
}
func parseParentheses(c context.Context, p *Parser) (Evaluable, error) {
eval, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
switch p.Scan() {
case ')':
return eval, nil
default:
return nil, p.Expected("parentheses", ')')
}
}
func (p *Parser) parseOperator(c context.Context, stack *stageStack, eval Evaluable) (st stage, err error) {
for {
scan := p.Scan()
op := p.TokenText()
mustOp := false
if p.isSymbolOperation(scan) {
scan = p.Peek()
for p.isSymbolOperation(scan) && p.isOperatorPrefix(op+string(scan)) {
mustOp = true
op += string(scan)
p.Next()
scan = p.Peek()
}
} else if scan != scanner.Ident {
p.Camouflage("operator")
return stage{Evaluable: eval}, nil
}
switch operator := p.operators[op].(type) {
case *infix:
return stage{
Evaluable: eval,
infixBuilder: operator.builder,
operatorPrecedence: operator.operatorPrecedence,
}, nil
case directInfix:
return stage{
Evaluable: eval,
infixBuilder: operator.infixBuilder,
operatorPrecedence: operator.operatorPrecedence,
}, nil
case postfix:
if err = stack.push(stage{
operatorPrecedence: operator.operatorPrecedence,
Evaluable: eval,
}); err != nil {
return stage{}, err
}
eval, err = operator.f(c, p, stack.pop().Evaluable, operator.operatorPrecedence)
if err != nil {
return
}
continue
}
if !mustOp {
p.Camouflage("operator")
return stage{Evaluable: eval}, nil
}
return stage{}, fmt.Errorf("unknown operator %s", op)
}
}
func parseIdent(c context.Context, p *Parser) (call string, alternative func() (Evaluable, error), err error) {
token := p.TokenText()
return token,
func() (Evaluable, error) {
fullname := token
keys := []Evaluable{p.Const(token)}
for {
scan := p.Scan()
switch scan {
case '.':
scan = p.Scan()
switch scan {
case scanner.Ident:
token = p.TokenText()
keys = append(keys, p.Const(token))
default:
return nil, p.Expected("field", scanner.Ident)
}
case '(':
args, err := p.parseArguments(c)
if err != nil {
return nil, err
}
return p.callEvaluable(fullname, p.Var(keys...), args...), nil
case '[':
key, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
switch p.Scan() {
case ']':
keys = append(keys, key)
default:
return nil, p.Expected("array key", ']')
}
default:
p.Camouflage("variable", '.', '(', '[')
return p.Var(keys...), nil
}
}
}, nil
}
func (p *Parser) parseArguments(c context.Context) (args []Evaluable, err error) {
if p.Scan() == ')' {
return
}
p.Camouflage("scan arguments", ')')
for {
arg, err := p.ParseExpression(c)
args = append(args, arg)
if err != nil {
return nil, err
}
switch p.Scan() {
case ')':
return args, nil
case ',':
default:
return nil, p.Expected("arguments", ')', ',')
}
}
}
func inArray(a, b interface{}) (interface{}, error) {
col, ok := b.([]interface{})
if !ok {
return nil, fmt.Errorf("expected type []interface{} for in operator but got %T", b)
}
for _, value := range col {
if reflect.DeepEqual(a, value) {
return true, nil
}
}
return false, nil
}
func parseIf(c context.Context, p *Parser, e Evaluable) (Evaluable, error) {
a, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
b := p.Const(nil)
switch p.Scan() {
case ':':
b, err = p.ParseExpression(c)
if err != nil {
return nil, err
}
case scanner.EOF:
default:
return nil, p.Expected("<> ? <> : <>", ':', scanner.EOF)
}
return func(c context.Context, v interface{}) (interface{}, error) {
x, err := e(c, v)
if err != nil {
return nil, err
}
if valX := reflect.ValueOf(x); x == nil || valX.IsZero() {
return b(c, v)
}
return a(c, v)
}, nil
}
func parseJSONArray(c context.Context, p *Parser) (Evaluable, error) {
evals := []Evaluable{}
for {
switch p.Scan() {
default:
p.Camouflage("array", ',', ']')
eval, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
evals = append(evals, eval)
case ',':
case ']':
return func(c context.Context, v interface{}) (interface{}, error) {
vs := make([]interface{}, len(evals))
for i, e := range evals {
eval, err := e(c, v)
if err != nil {
return nil, err
}
vs[i] = eval
}
return vs, nil
}, nil
}
}
}
func parseJSONObject(c context.Context, p *Parser) (Evaluable, error) {
type kv struct {
key Evaluable
value Evaluable
}
evals := []kv{}
for {
switch p.Scan() {
default:
p.Camouflage("object", ',', '}')
key, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
if p.Scan() != ':' {
if err != nil {
return nil, p.Expected("object", ':')
}
}
value, err := p.ParseExpression(c)
if err != nil {
return nil, err
}
evals = append(evals, kv{key, value})
case ',':
case '}':
return func(c context.Context, v interface{}) (interface{}, error) {
vs := map[string]interface{}{}
for _, e := range evals {
value, err := e.value(c, v)
if err != nil {
return nil, err
}
key, err := e.key.EvalString(c, v)
if err != nil {
return nil, err
}
vs[key] = value
}
return vs, nil
}, nil
}
}
}

147
gval/parser.go Normal file
View File

@ -0,0 +1,147 @@
package gval
import (
"bytes"
"fmt"
"strings"
"text/scanner"
"unicode"
)
// Parser parses expressions in a Language into an Evaluable
type Parser struct {
scanner scanner.Scanner
Language
lastScan rune
camouflage error
}
func newParser(expression string, l Language) *Parser {
sc := scanner.Scanner{}
sc.Init(strings.NewReader(expression))
sc.Error = func(*scanner.Scanner, string) {}
sc.Filename = expression + "\t"
p := &Parser{scanner: sc, Language: l}
p.resetScannerProperties()
return p
}
func (p *Parser) resetScannerProperties() {
p.scanner.Whitespace = scanner.GoWhitespace
p.scanner.Mode = scanner.GoTokens
p.scanner.IsIdentRune = func(r rune, pos int) bool {
return unicode.IsLetter(r) || r == '_' || (pos > 0 && unicode.IsDigit(r))
}
}
// SetWhitespace sets the behavior of the whitespace matcher. The given
// characters must be less than or equal to 0x20 (' ').
func (p *Parser) SetWhitespace(chars ...rune) {
var mask uint64
for _, char := range chars {
mask |= 1 << uint(char)
}
p.scanner.Whitespace = mask
}
// SetMode sets the tokens that the underlying scanner will match.
func (p *Parser) SetMode(mode uint) {
p.scanner.Mode = mode
}
// SetIsIdentRuneFunc sets the function that matches ident characters in the
// underlying scanner.
func (p *Parser) SetIsIdentRuneFunc(fn func(ch rune, i int) bool) {
p.scanner.IsIdentRune = fn
}
// Scan reads the next token or Unicode character from source and returns it.
// It only recognizes tokens t for which the respective Mode bit (1<<-t) is set.
// It returns scanner.EOF at the end of the source.
func (p *Parser) Scan() rune {
if p.isCamouflaged() {
p.camouflage = nil
return p.lastScan
}
p.camouflage = nil
p.lastScan = p.scanner.Scan()
return p.lastScan
}
func (p *Parser) isCamouflaged() bool {
return p.camouflage != nil && p.camouflage != errCamouflageAfterNext
}
// Camouflage rewind the last Scan(). The Parser holds the camouflage error until
// the next Scan()
// Do not call Rewind() on a camouflaged Parser
func (p *Parser) Camouflage(unit string, expected ...rune) {
if p.isCamouflaged() {
panic(fmt.Errorf("can only Camouflage() after Scan(): %w", p.camouflage))
}
p.camouflage = p.Expected(unit, expected...)
}
// Peek returns the next Unicode character in the source without advancing
// the scanner. It returns EOF if the scanner's position is at the last
// character of the source.
// Do not call Peek() on a camouflaged Parser
func (p *Parser) Peek() rune {
if p.isCamouflaged() {
panic("can not Peek() on camouflaged Parser")
}
return p.scanner.Peek()
}
var errCamouflageAfterNext = fmt.Errorf("Camouflage() after Next()")
// Next reads and returns the next Unicode character.
// It returns EOF at the end of the source.
// Do not call Next() on a camouflaged Parser
func (p *Parser) Next() rune {
if p.isCamouflaged() {
panic("can not Next() on camouflaged Parser")
}
p.camouflage = errCamouflageAfterNext
return p.scanner.Next()
}
// TokenText returns the string corresponding to the most recently scanned token.
// Valid after calling Scan().
func (p *Parser) TokenText() string {
return p.scanner.TokenText()
}
// Expected returns an error signaling an unexpected Scan() result
func (p *Parser) Expected(unit string, expected ...rune) error {
return unexpectedRune{unit, expected, p.lastScan}
}
type unexpectedRune struct {
unit string
expected []rune
got rune
}
func (err unexpectedRune) Error() string {
exp := bytes.Buffer{}
runes := err.expected
switch len(runes) {
default:
for _, r := range runes[:len(runes)-2] {
exp.WriteString(scanner.TokenString(r))
exp.WriteString(", ")
}
fallthrough
case 2:
exp.WriteString(scanner.TokenString(runes[len(runes)-2]))
exp.WriteString(" or ")
fallthrough
case 1:
exp.WriteString(scanner.TokenString(runes[len(runes)-1]))
case 0:
return fmt.Sprintf("unexpected %s while scanning %s", scanner.TokenString(err.got), err.unit)
}
return fmt.Sprintf("unexpected %s while scanning %s expected %s", scanner.TokenString(err.got), err.unit, exp.String())
}

148
gval/parser_test.go Normal file
View File

@ -0,0 +1,148 @@
package gval
import (
"testing"
"text/scanner"
"unicode"
)
func TestParser_Scan(t *testing.T) {
tests := []struct {
name string
input string
Language
do func(p *Parser)
wantScan rune
wantToken string
wantPanic bool
}{
{
name: "camouflage",
input: "$abc",
do: func(p *Parser) {
p.Scan()
p.Camouflage("test")
},
wantScan: '$',
wantToken: "$",
},
{
name: "camouflage with next",
input: "$abc",
do: func(p *Parser) {
p.Scan()
p.Camouflage("test")
p.Next()
},
wantPanic: true,
},
{
name: "camouflage scan camouflage",
input: "$abc",
do: func(p *Parser) {
p.Scan()
p.Camouflage("test")
p.Scan()
p.Camouflage("test2")
},
wantScan: '$',
wantToken: "$",
},
{
name: "camouflage with peek",
input: "$abc",
do: func(p *Parser) {
p.Scan()
p.Camouflage("test")
p.Peek()
},
wantPanic: true,
},
{
name: "next and peek",
input: "$#abc",
do: func(p *Parser) {
p.Scan()
p.Next()
p.Peek()
},
wantScan: scanner.Ident,
wantToken: "abc",
},
{
name: "scan token camouflage token",
input: "abc",
do: func(p *Parser) {
p.Scan()
p.TokenText()
p.Camouflage("test")
},
wantScan: scanner.Ident,
wantToken: "abc",
},
{
name: "scan token peek camouflage token",
input: "abc",
do: func(p *Parser) {
p.Scan()
p.TokenText()
p.Peek()
p.Camouflage("test")
},
wantScan: scanner.Ident,
wantToken: "abc",
},
{
name: "tokenize all whitespace",
input: "foo\tbar\nbaz",
do: func(p *Parser) {
p.SetWhitespace()
p.Scan()
},
wantScan: '\t',
wantToken: "\t",
},
{
name: "custom ident",
input: "$#foo",
do: func(p *Parser) {
p.SetIsIdentRuneFunc(func(ch rune, i int) bool { return unicode.IsLetter(ch) || ch == '#' })
p.Scan()
},
wantScan: scanner.Ident,
wantToken: "#foo",
},
{
name: "do not scan idents",
input: "abc",
do: func(p *Parser) {
p.SetMode(scanner.GoTokens ^ scanner.ScanIdents)
p.Scan()
},
wantScan: 'b',
wantToken: "b",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
err := recover()
if err != nil && !tt.wantPanic {
t.Fatalf("unexpected panic: %v", err)
}
}()
p := newParser(tt.input, tt.Language)
tt.do(p)
if tt.wantPanic {
return
}
scan := p.Scan()
token := p.TokenText()
if scan != tt.wantScan || token != tt.wantToken {
t.Errorf("Parser.Scan() = %v (%v), want %v (%v)", scan, token, tt.wantScan, tt.wantToken)
}
})
}
}

165
main.go
View File

@ -3,19 +3,18 @@ package main
import (
_ "embed"
"fmt"
"image/color"
"okemu/config"
"okemu/debuger"
"okemu/debug"
"okemu/debug/listener"
"okemu/logger"
"okemu/okean240"
"okemu/okean240/fdc"
"okemu/z80/dis"
"sync/atomic"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/driver/desktop"
"fyne.io/fyne/v2/widget"
)
@ -49,107 +48,38 @@ func main() {
// Reconfigure logging by config values
// logger.ReconfigureLogging(conf)
computer := okean240.New(conf)
debugger := debug.NewDebugger()
computer := okean240.NewComputer(conf, debugger)
computer.SetSerialBytes(serialBytes)
computer.LoadFloppy()
if conf.FDC.AutoLoadB {
err := computer.LoadFloppy(fdc.FloppyB)
if err != nil {
// show message
}
}
if conf.FDC.AutoLoadC {
err := computer.LoadFloppy(fdc.FloppyC)
if err != nil {
// show message
}
}
disasm := dis.NewDisassembler(computer)
w, raster, label := mainWindow(computer)
go emulator(computer)
go screen(computer, raster, label)
go debuger.SetupTcpHandler(conf, computer)
if conf.Debugger.Enabled {
go listener.SetupTcpHandler(conf, debugger, disasm, computer)
}
(*w).ShowAndRun()
}
func mainWindow(computer *okean240.ComputerType) (*fyne.Window, *canvas.Raster, *widget.Label) {
emulatorApp := app.New()
w := emulatorApp.NewWindow("Океан 240.2")
w.Canvas().SetOnTypedKey(
func(key *fyne.KeyEvent) {
computer.PutKey(key)
})
w.Canvas().SetOnTypedRune(
func(key rune) {
computer.PutRune(key)
})
addShortcuts(w.Canvas(), computer)
label := widget.NewLabel(fmt.Sprintf("Screen size: %dx%d", computer.ScreenWidth(), computer.ScreenHeight()))
raster := canvas.NewRasterWithPixels(
func(x, y, w, h int) color.Color {
var xx uint16
if computer.ScreenWidth() == 512 {
xx = uint16(x)
} else {
xx = uint16(x) / 2
}
return computer.GetPixel(xx, uint16(y/2))
})
raster.Resize(fyne.NewSize(512, 512))
raster.SetMinSize(fyne.NewSize(512, 512))
centerRaster := container.NewCenter(raster)
w.Resize(fyne.NewSize(600, 600))
hBox := container.NewHBox(
//widget.NewButton("++", func() {
// computer.IncOffset()
//}),
//widget.NewButton("--", func() {
// computer.DecOffset()
//}),
widget.NewButton("Ctrl+C", func() {
computer.PutCtrlKey(0x03)
}),
widget.NewButton("Load Floppy", func() {
computer.LoadFloppy()
}),
widget.NewButton("Save Floppy", func() {
computer.SaveFloppy()
}),
widget.NewButton("RUN1", func() {
computer.SetRamBytes(ramBytes1)
}),
widget.NewButton("RUN2", func() {
computer.SetRamBytes(ramBytes2)
}),
widget.NewButton("DUMP", func() {
computer.Dump(0x100, 15000)
}),
widget.NewCheck("Full speed", func(b bool) {
fullSpeed.Store(b)
if b {
computer.SetCPUFrequency(50_000_000)
} else {
computer.SetCPUFrequency(2_500_000)
}
}),
widget.NewSeparator(),
widget.NewButton("Reset", func() {
needReset = true
//computer.Reset(conf)
}),
widget.NewSeparator(),
widget.NewButton("Закрыть", func() {
emulatorApp.Quit()
}),
)
vBox := container.NewVBox(
centerRaster,
label,
hBox,
)
w.SetContent(vBox)
return &w, raster, label
}
func screen(computer *okean240.ComputerType, raster *canvas.Raster, label *widget.Label) {
ticker := time.NewTicker(20 * time.Millisecond)
frame := 0
@ -164,13 +94,17 @@ func screen(computer *okean240.ComputerType, raster *canvas.Raster, label *widge
if frame%50 == 0 {
freq = computer.Cycles() - pre
pre = computer.Cycles()
label.SetText(fmt.Sprintf("Screen size: %dx%d F: %d", computer.ScreenWidth(), computer.ScreenHeight(), freq))
label.SetText(formatLabel(computer, freq))
}
raster.Refresh()
})
}
}
func formatLabel(computer *okean240.ComputerType, freq uint64) string {
return fmt.Sprintf("Screen size: %dx%d | F: %d | Debugger: %s", computer.ScreenWidth(), computer.ScreenHeight(), freq, computer.DebuggerState())
}
const ticksPerTact uint64 = 4
func emulator(computer *okean240.ComputerType) {
@ -184,35 +118,24 @@ func emulator(computer *okean240.ComputerType) {
// 1.5 MHz
computer.TimerClk()
}
if !computer.IsStepMode() || computer.IsRunMode() {
var bp uint16 = 0
if fullSpeed.Load() {
_, bp = computer.Do()
} else {
if ticks >= nextClock {
var t uint32
t, bp = computer.Do()
nextClock = ticks + uint64(t)*ticksPerTact
}
}
// Breakpoint hit
if bp > 0 {
debuger.BreakpointHit(bp)
var bp uint16 = 0
var typ byte = 0
if fullSpeed.Load() {
_, bp, typ = computer.Do()
} else {
if ticks >= nextClock {
var t uint32
t, bp, typ = computer.Do()
nextClock = ticks + uint64(t)*ticksPerTact
}
}
// Breakpoint hit
if bp > 0 || typ != 0 {
listener.BreakpointHit(bp, typ)
}
if needReset {
computer.Reset()
needReset = false
}
}
}
// Add shortcuts for all Ctrl+<Letter>
func addShortcuts(c fyne.Canvas, computer *okean240.ComputerType) {
// Add shortcuts for Ctrl+A to Ctrl+Z
for kName := 'A'; kName <= 'Z'; kName++ {
kk := fyne.KeyName(kName)
sc := &desktop.CustomShortcut{KeyName: kk, Modifier: fyne.KeyModifierControl}
c.AddShortcut(sc, func(shortcut fyne.Shortcut) { computer.PutCtrlKey(byte(kName&0xff) - 0x40) })
}
}

View File

@ -3,8 +3,11 @@ package okean240
import (
_ "embed"
"encoding/binary"
"encoding/json"
"fmt"
"image/color"
"okemu/config"
"okemu/debug"
"okemu/okean240/fdc"
"okemu/okean240/pic"
"okemu/okean240/pit"
@ -22,39 +25,36 @@ import (
const DefaultCPUFrequency = 2_500_000
type Breakpoint struct {
addr uint16
enabled bool
}
type ComputerType struct {
cpu *c99.Z80
memory Memory
ioPorts [256]byte
cycles uint64
dd17EnableOut bool
colorMode bool
screenWidth int
screenHeight int
vRAM *RamBlock
palette byte
bgColor byte
pit *pit.I8253
usart *usart.I8251
pic *pic.I8259
fdc *fdc.FloppyDriveController
kbdBuffer []byte
vShift byte
hShift byte
stepMode bool
runMode bool
bpEnabled bool
breakpoints [MaxBreakpoints]Breakpoint
//aOffset uint16
cpuFrequency uint32
cpu *c99.Z80
memory Memory
ioPorts [256]byte
cycles uint64
tstatesPartial uint64
dd17EnableOut bool
colorMode bool
screenWidth int
screenHeight int
vRAM *RamBlock
palette byte
bgColor byte
pit *pit.I8253
usart *usart.I8251
pic *pic.I8259
fdc *fdc.FloppyDriveController
kbdBuffer []byte
vShift byte
hShift byte
cpuFrequency uint32
//
debugger *debug.Debugger
}
type Snapshot struct {
CPU *z80.CPU `json:"cpu,omitempty"`
Memory string `json:"memory,omitempty"`
}
const MaxBreakpoints = 256
const VRAMBlock0 = 3
const VRAMBlock1 = 7
const VidVsuBit = 0x80
@ -71,66 +71,15 @@ type ComputerInterface interface {
PutCtrlKey(shortcut fyne.Shortcut)
SaveFloppy()
LoadFloppy()
CPUState() *z80.Z80CPU
SetCPUState(state *z80.Z80CPU)
StepMode() bool
SetStepMode(step bool)
ClearMemBreakpoints()
SetBreakpointsEnabled(enabled bool)
IsBreakpoint() bool
//Dump(start uint16, length uint16)
CPUState() *z80.CPU
SetCPUState(state *z80.CPU)
}
func (c *ComputerType) SetBreakpointsEnabled(enabled bool) {
c.bpEnabled = enabled
}
func (c *ComputerType) IsBreakpointsEnabled() bool {
return c.bpEnabled
}
func (c *ComputerType) SetBreakpoint(no uint16, addr uint16) {
if no > 0 && no <= MaxBreakpoints {
c.breakpoints[no-1].addr = addr
log.Debugf("BP[%d] SET AT PC=%04X", no, addr)
} else {
log.Warnf("Breakpoint number %d out or range!", no)
}
}
func (c *ComputerType) SetBreakpointEnabled(no uint16, enabled bool) {
if no <= MaxBreakpoints && no > 0 {
c.breakpoints[no-1].enabled = enabled
} else {
log.Warnf("Breakpoint number %d out or range!", no)
}
}
func (c *ComputerType) IsBreakpointEnabled(no uint16) bool {
if no <= MaxBreakpoints && no > 0 {
return c.breakpoints[no-1].enabled
}
log.Warnf("Breakpoint number %d out or range!", no)
return false
}
func (c *ComputerType) ClearMemBreakpoints() {
log.Warnf("Clearing memory bpEnabled unimplemented!")
}
func (c *ComputerType) SetStepMode(step bool) {
c.stepMode = step
}
func (c *ComputerType) IsStepMode() bool {
return c.stepMode
}
func (c *ComputerType) GetCPUState() *z80.Z80CPU {
func (c *ComputerType) GetCPUState() *z80.CPU {
return c.cpu.GetState()
}
func (c *ComputerType) SetCPUState(state *z80.Z80CPU) {
func (c *ComputerType) SetCPUState(state *z80.CPU) {
c.cpu.SetState(state)
}
@ -146,8 +95,8 @@ func (c *ComputerType) MemWrite(addr uint16, val byte) {
c.memory.MemWrite(addr, val)
}
// New Builds new computer
func New(cfg *config.OkEmuConfig) *ComputerType {
// NewComputer Builds new computer
func NewComputer(cfg *config.OkEmuConfig, deb *debug.Debugger) *ComputerType {
c := ComputerType{}
c.memory = Memory{}
c.memory.Init(cfg.MonitorFile, cfg.CPMFile)
@ -155,6 +104,7 @@ func New(cfg *config.OkEmuConfig) *ComputerType {
c.cpu = c99.New(&c)
c.cycles = 0
c.tstatesPartial = 0
c.dd17EnableOut = false
c.screenWidth = 512
c.screenHeight = 256
@ -169,58 +119,80 @@ func New(cfg *config.OkEmuConfig) *ComputerType {
c.pit = pit.New()
c.usart = usart.New()
c.pic = pic.New()
c.fdc = fdc.New(cfg)
c.fdc = fdc.NewFDC(cfg)
c.cpuFrequency = DefaultCPUFrequency
c.bpEnabled = false
c.breakpoints = [256]Breakpoint{}
for i := range c.breakpoints {
c.breakpoints[i] = Breakpoint{}
c.breakpoints[i].enabled = false
c.breakpoints[i].addr = 0
}
c.debugger = deb
return &c
}
func (c *ComputerType) Reset() {
c.cpu.Reset()
c.cycles = 0
//c.vShift = 0
//c.hShift = 0
//c.memory = Memory{}
//c.memory.Init(cfg.MonitorFile, cfg.CPMFile)
//c.dd17EnableOut = false
//c.screenWidth = 256
//c.screenHeight = 256
//c.vRAM = c.memory.allMemory[3]
c.tstatesPartial = 0
}
func (c *ComputerType) SetRunMode(run bool) {
c.runMode = run
func (c *ComputerType) getContext() map[string]interface{} {
context := make(map[string]interface{})
s := c.cpu.GetState()
context["A"] = s.A
context["B"] = s.B
context["C"] = s.C
context["D"] = s.D
context["E"] = s.E
context["H"] = s.H
context["L"] = s.L
context["A'"] = s.AAlt
context["B'"] = s.BAlt
context["C'"] = s.CAlt
context["D'"] = s.DAlt
context["E'"] = s.EAlt
context["H'"] = s.HAlt
context["L'"] = s.LAlt
context["PC"] = s.PC
context["SP"] = s.SP
context["IX"] = s.IX
context["IY"] = s.IY
context["ZF"] = s.Flags.Z
context["SF"] = s.Flags.S
context["NF"] = s.Flags.N
context["PF"] = s.Flags.P
context["HF"] = s.Flags.H
context["YF"] = s.Flags.Y
context["XF"] = s.Flags.X
context["CF"] = s.Flags.C
context["BC"] = uint16(s.B)<<8 | uint16(s.C)
context["DE"] = uint16(s.D)<<8 | uint16(s.E)
context["HL"] = uint16(s.H)<<8 | uint16(s.L)
context["AF"] = uint16(s.A)<<8 | uint16(s.Flags.GetFlags())
return context
}
func (c *ComputerType) IsRunMode() bool {
return c.runMode
}
func (c *ComputerType) Do() (uint32, uint16) {
// check breakpoints
if c.bpEnabled && c.runMode {
for no, bp := range c.breakpoints {
if bp.enabled && bp.addr == c.cpu.GetState().PC {
c.runMode = false
return 0, uint16(no + 1)
func (c *ComputerType) Do() (uint32, uint16, byte) {
ticks := uint32(0)
var memAccess *map[uint16]byte
if c.debugger.StepMode() {
if c.debugger.RunMode() || c.debugger.DoStep() {
if c.debugger.RunInst() > 0 {
// skip first instruction after run-mode activated
bpHit, bp := c.debugger.CheckBreakpoints(c.getContext())
if bpHit {
//c.debugger.SetRunMode(false)
return 0, bp, 0
}
}
c.debugger.SaveHistory(c.cpu.GetState())
ticks, memAccess = c.cpu.RunInstruction()
mHit, mAddr, mTyp := c.debugger.CheckMemBreakpoints(memAccess)
if mHit {
return ticks, mAddr, mTyp
}
}
} else {
ticks, memAccess = c.cpu.RunInstruction()
}
ticks := c.cpu.RunInstruction()
c.cycles += uint64(ticks)
//pc := c.cpu.GetState().PC
//if pc >= 0xfea3 && pc <= 0xff25 {
// c.cpu.DebugOutput()
//}
return ticks, 0
c.tstatesPartial += uint64(ticks)
return ticks, 0, 0
}
func (c *ComputerType) GetPixel(x uint16, y uint16) color.RGBA {
@ -303,6 +275,14 @@ func (c *ComputerType) Cycles() uint64 {
return c.cycles
}
func (c *ComputerType) ResetTStatesPartial() {
c.tstatesPartial = 0
}
func (c *ComputerType) TStatesPartial() uint64 {
return c.tstatesPartial
}
func (c *ComputerType) TimerClk() {
// DD70 KR580VI53 CLK0, CKL1 @ 1.5MHz
c.pit.Tick(0)
@ -319,12 +299,12 @@ func (c *ComputerType) TimerClk() {
}
}
func (c *ComputerType) LoadFloppy() {
c.fdc.LoadFloppy()
func (c *ComputerType) LoadFloppy(drive byte) error {
return c.fdc.LoadFloppy(drive)
}
func (c *ComputerType) SaveFloppy() {
c.fdc.SaveFloppy()
func (c *ComputerType) SaveFloppy(drive byte) error {
return c.fdc.SaveFloppy(drive)
}
func (c *ComputerType) SetSerialBytes(bytes []byte) {
@ -379,3 +359,51 @@ func (c *ComputerType) CPUFrequency() uint32 {
func (c *ComputerType) SetCPUFrequency(frequency uint32) {
c.cpuFrequency = frequency
}
func (c *ComputerType) DebuggerState() string {
if c.debugger.StepMode() {
if c.debugger.RunMode() {
return "Run"
}
return "Step"
}
return "Off"
}
func (c *ComputerType) memoryAsHexStr() string {
res := ""
for addr := 0; addr <= 65535; addr++ {
res += fmt.Sprintf("%02X", c.memory.MemRead(uint16(addr)))
}
return res
}
func (c *ComputerType) SaveSnapshot(fn string) error {
// create snapshot file
file, err := os.Create(fn)
if err != nil {
return err
}
defer func() {
err := file.Close()
if err != nil {
log.Error(err)
}
}()
// take snapshot
s := Snapshot{
CPU: c.cpu.GetState(),
Memory: c.memoryAsHexStr(),
}
// convert to JSON
b, err := json.Marshal(s)
if err != nil {
return err
}
// and save
err = binary.Write(file, binary.LittleEndian, b)
if err != nil {
return err
}
return nil
}

View File

@ -10,6 +10,7 @@ package fdc
import (
"bytes"
"encoding/binary"
"errors"
"okemu/config"
"os"
"slices"
@ -21,6 +22,9 @@ import (
// Floppy parameters
const (
FloppyB = 0
FloppyC = 1
TotalDrives = 2
FloppySizeK = 720
SectorSize = 512
@ -84,8 +88,8 @@ type FloppyDriveController struct {
//curSector *SectorType
bytePtr uint16
trackBuffer []byte
floppyBFile string
floppyCFile string
floppyFile []string
config *config.OkEmuConfig
}
type FloppyDriveControllerInterface interface {
@ -98,7 +102,7 @@ type FloppyDriveControllerInterface interface {
SetData(value byte)
Data() byte
Drq() byte
SaveFloppy()
SaveFloppy(drive byte)
GetSectorNo() uint16
Track() byte
Sector() byte
@ -326,17 +330,21 @@ func (f *FloppyDriveController) Drq() byte {
return f.drq
}
func (f *FloppyDriveController) LoadFloppy() {
loadFloppy(&f.sectors[0], f.floppyBFile)
loadFloppy(&f.sectors[1], f.floppyCFile)
func (f *FloppyDriveController) LoadFloppy(drive byte) error {
if drive < TotalDrives {
return loadFloppy(&f.sectors[drive], f.floppyFile[drive])
}
return errors.New("DriveNo " + strconv.Itoa(int(drive)) + " out of range")
}
func (f *FloppyDriveController) SaveFloppy() {
saveFloppy(&f.sectors[0], f.floppyBFile)
saveFloppy(&f.sectors[1], f.floppyCFile)
func (f *FloppyDriveController) SaveFloppy(drive byte) error {
if drive < TotalDrives {
return saveFloppy(&f.sectors[drive], f.floppyFile[drive])
}
return errors.New("DriveNo " + strconv.Itoa(int(drive)) + " out of range")
}
func New(conf *config.OkEmuConfig) *FloppyDriveController {
func NewFDC(conf *config.OkEmuConfig) *FloppyDriveController {
sec := [2][SizeInSectors]SectorType{}
// for each drive
for d := 0; d < TotalDrives; d++ {
@ -346,20 +354,19 @@ func New(conf *config.OkEmuConfig) *FloppyDriveController {
}
}
return &FloppyDriveController{
sideNo: 0,
ddEn: 0,
init: 0,
drive: 0,
mot1: 0,
mot0: 0,
intRq: 0,
motSt: 0,
drq: 0,
lastCmd: 0xff,
sectors: sec,
bytePtr: 0xffff,
floppyBFile: conf.FloppyB,
floppyCFile: conf.FloppyC,
sideNo: 0,
ddEn: 0,
init: 0,
drive: 0,
mot1: 0,
mot0: 0,
intRq: 0,
motSt: 0,
drq: 0,
lastCmd: 0xff,
sectors: sec,
bytePtr: 0xffff,
floppyFile: []string{conf.FDC.FloppyB, conf.FDC.FloppyC},
}
}
@ -393,12 +400,12 @@ func (f *FloppyDriveController) Sector() byte {
}
// loadFloppy load floppy image to sector buffer from file
func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) {
func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) error {
log.Debugf("Load Floppy content from file %s.", fileName)
file, err := os.Open(fileName)
if err != nil {
log.Error(err)
return
return err
}
defer func(file *os.File) {
@ -416,17 +423,19 @@ func loadFloppy(sectors *[SizeInSectors]SectorType, fileName string) {
}
if err != nil {
log.Error("Load floppy content failed:", err)
break
return err
}
}
return nil
}
func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) {
// saveFloppy Save specified sectors to file with name fileName
func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) error {
log.Debugf("Save Floppy to file %s.", fileName)
file, err := os.Create(fileName)
if err != nil {
log.Error(err)
return
return err
}
defer func(file *os.File) {
err := file.Close()
@ -443,7 +452,8 @@ func saveFloppy(sectors *[SizeInSectors]SectorType, fileName string) {
}
if err != nil {
log.Error("Save floppy content failed:", err)
break
return err
}
}
return nil
}

View File

@ -1,8 +1,15 @@
logFile: "okemu.log"
logLevel: "info"
monitorFile: "rom/MON_r8_9c6c6546.bin"
# cpmFile: "rom/CPM_r7_b89a7e16.bin"
cpmFile: "rom/CPM_r8_bc0695e4.bin"
floppyB: "floppy/floppyB.okd"
floppyC: "floppy/floppyC.okd"
port: 10001
fdc:
autoLoadB: true
floppyB: "floppy/floppyB.okd"
autoLoadC: true
floppyC: "floppy/floppyC.okd"
debugger:
enabled: true
port: 10001

View File

@ -33,9 +33,15 @@ type Z80 struct {
intPending bool
nmiPending bool
core z80.MemIoRW
core z80.MemIoRW
memAccess map[uint16]byte
}
const (
MemAccessRead = 1
MemAccessWrite = 2
)
// New initializes a Z80 instance and return pointer to it
func New(core z80.MemIoRW) *Z80 {
z := Z80{}
@ -90,7 +96,9 @@ func New(core z80.MemIoRW) *Z80 {
}
// RunInstruction executes the next instruction in memory + handles interrupts
func (z *Z80) RunInstruction() uint32 {
func (z *Z80) RunInstruction() (uint32, *map[uint16]byte) {
z.memAccess = map[uint16]byte{}
pre := z.cycleCount
if z.isHalted {
z.execOpcode(0x00)
@ -99,10 +107,10 @@ func (z *Z80) RunInstruction() uint32 {
z.execOpcode(opcode)
}
z.processInterrupts()
return z.cycleCount - pre
return z.cycleCount - pre, &z.memAccess
}
func (z *Z80) SetState(state *z80.Z80CPU) {
func (z *Z80) SetState(state *z80.CPU) {
z.cycleCount = 0
z.a = state.A
z.b = state.B
@ -148,8 +156,8 @@ func (z *Z80) SetState(state *z80.Z80CPU) {
z.nmiPending = false
z.intData = 0
}
func (z *Z80) GetState() *z80.Z80CPU {
return &z80.Z80CPU{
func (z *Z80) GetState() *z80.CPU {
return &z80.CPU{
A: z.a,
B: z.b,
C: z.c,
@ -211,3 +219,7 @@ func (z *Z80) getAltFlags() z80.FlagsType {
C: z.f_&0x01 != 0,
}
}
func (z *Z80) PC() uint16 {
return z.pc
}

View File

@ -3,18 +3,24 @@ package c99
import log "github.com/sirupsen/logrus"
func (z *Z80) rb(addr uint16) byte {
z.memAccess[addr] = MemAccessRead
return z.core.MemRead(addr)
}
func (z *Z80) wb(addr uint16, val byte) {
z.memAccess[addr] = MemAccessWrite
z.core.MemWrite(addr, val)
}
func (z *Z80) rw(addr uint16) uint16 {
z.memAccess[addr] = MemAccessRead
z.memAccess[addr+1] = MemAccessRead
return (uint16(z.core.MemRead(addr+1)) << 8) | uint16(z.core.MemRead(addr))
}
func (z *Z80) ww(addr uint16, val uint16) {
z.memAccess[addr] = MemAccessWrite
z.memAccess[addr+1] = MemAccessWrite
z.core.MemWrite(addr, byte(val))
z.core.MemWrite(addr+1, byte(val>>8))
}
@ -30,13 +36,13 @@ func (z *Z80) popW() uint16 {
}
func (z *Z80) nextB() byte {
b := z.rb(z.pc)
b := z.core.MemRead(z.pc)
z.pc++
return b
}
func (z *Z80) nextW() uint16 {
w := z.rw(z.pc)
w := (uint16(z.core.MemRead(z.pc+1)) << 8) | uint16(z.core.MemRead(z.pc))
z.pc += 2
return w
}

View File

@ -19,58 +19,58 @@ type CPUInterface interface {
// RunInstruction Run single instruction, return number of CPU cycles
RunInstruction() uint32
// GetState Get current CPU state
GetState() *Z80CPU
GetState() *CPU
// SetState Set current CPU state
SetState(state *Z80CPU)
SetState(state *CPU)
// DebugOutput out current CPU state
DebugOutput()
}
// FlagsType - Processor flags
type FlagsType struct {
S bool
Z bool
Y bool
H bool
X bool
P bool
N bool
C bool
S bool `json:"s,omitempty"`
Z bool `json:"z,omitempty"`
Y bool `json:"y,omitempty"`
H bool `json:"h,omitempty"`
X bool `json:"x,omitempty"`
P bool `json:"p,omitempty"`
N bool `json:"n,omitempty"`
C bool `json:"c,omitempty"`
}
// Z80CPU - Processor state
type Z80CPU struct {
A byte
B byte
C byte
D byte
E byte
H byte
L byte
AAlt byte
BAlt byte
CAlt byte
DAlt byte
EAlt byte
HAlt byte
LAlt byte
IX uint16
IY uint16
I byte
R byte
SP uint16
PC uint16
Flags FlagsType
FlagsAlt FlagsType
IMode byte
Iff1 bool
Iff2 bool
Halted bool
DoDelayedDI bool
DoDelayedEI bool
CycleCount uint32
InterruptOccurred bool
MemPtr uint16
type CPU struct {
A byte `json:"a,omitempty"`
B byte `json:"b,omitempty"`
C byte `json:"c,omitempty"`
D byte `json:"d,omitempty"`
E byte `json:"e,omitempty"`
H byte `json:"h,omitempty"`
L byte `json:"l,omitempty"`
AAlt byte `json:"AAlt,omitempty"`
BAlt byte `json:"BAlt,omitempty"`
CAlt byte `json:"CAlt,omitempty"`
DAlt byte `json:"DAlt,omitempty"`
EAlt byte `json:"EAlt,omitempty"`
HAlt byte `json:"HAlt,omitempty"`
LAlt byte `json:"LAlt,omitempty"`
IX uint16 `json:"IX,omitempty"`
IY uint16 `json:"IY,omitempty"`
I byte `json:"i,omitempty"`
R byte `json:"r,omitempty"`
SP uint16 `json:"SP,omitempty"`
PC uint16 `json:"PC,omitempty"`
Flags FlagsType `json:"flags"`
FlagsAlt FlagsType `json:"flagsAlt"`
IMode byte `json:"IMode,omitempty"`
Iff1 bool `json:"iff1,omitempty"`
Iff2 bool `json:"iff2,omitempty"`
Halted bool `json:"halted,omitempty"`
DoDelayedDI bool `json:"doDelayedDI,omitempty"`
DoDelayedEI bool `json:"doDelayedEI,omitempty"`
CycleCount uint32 `json:"cycleCount,omitempty"`
InterruptOccurred bool `json:"interruptOccurred,omitempty"`
MemPtr uint16 `json:"memPtr,omitempty"`
//core MemIoRW
}
@ -132,7 +132,7 @@ func (f *FlagsType) GetFlagsStr() string {
return string(flags)
}
func (z *Z80CPU) IIFStr() string {
func (z *CPU) IIFStr() string {
flags := []byte{'-', '-'}
if z.Iff1 {
flags[0] = '1'
@ -166,3 +166,7 @@ func (f *FlagsType) SetFlags(flags byte) {
f.N = flags&0x02 != 0
f.C = flags&0x01 != 0
}
func (z *CPU) GetPC() uint16 {
return z.PC
}

View File

@ -413,7 +413,7 @@ func TestZ80Fuse(t *testing.T) {
}
func setComputerState(test Z80TestIn) {
state := z80.Z80CPU{
state := z80.CPU{
A: byte(test.registers.AF >> 8),
B: byte(test.registers.BC >> 8),
C: byte(test.registers.BC),

672
z80/dis/z80disasm.go Normal file
View File

@ -0,0 +1,672 @@
package dis
import (
"fmt"
"okemu/z80"
"strings"
)
type Disassembler struct {
pc uint16
core z80.MemIoRW
}
type Disassembly interface {
Disassm(pc uint16) string
}
func NewDisassembler(core z80.MemIoRW) *Disassembler {
d := Disassembler{
pc: 0,
core: core,
}
return &d
}
// opcode & 0x07
var operands = []string{"B", "C", "D", "E", "H", "L", "(HL)", "A"}
var aluOp = []string{"ADD A" + sep, "ADC A" + sep, "SUB ", "SBC A" + sep, "AND ", "XOR ", "OR ", "CP "}
const sep = ", "
func (d *Disassembler) jp(op, cond string) string {
addr := d.getW()
if cond != "" {
cond += sep
}
return fmt.Sprintf("%s %s%s", op, cond, addr)
}
func (d *Disassembler) jr(op, cond string) string {
addr := d.pc
offset := d.getByte()
if offset&0x80 != 0 {
addr += 0xFF00 | uint16(offset)
} else {
addr += d.pc + uint16(offset)
}
if cond != "" {
cond += sep
}
return fmt.Sprintf("%s %s0x%04X", op, cond, addr)
}
func (d *Disassembler) getByte() byte {
b := d.core.MemRead(d.pc)
d.pc++
return b
}
func (d *Disassembler) Disassm(pc uint16) string {
d.pc = pc
result := fmt.Sprintf(" %04X ", d.pc)
op := d.getByte()
switch {
// == 00:0F
case op == 0x00:
result += "NOP"
case op == 0x01:
result += "LD BC" + sep + d.getW()
case op == 0x02:
result += "LD (BC)" + sep + "A"
case op == 0x03:
result += "INC BC"
case op == 0x04:
result += "INC B"
case op == 0x05:
result += "DEC B"
case op == 0x06:
result += "LD B" + sep + d.getB()
case op == 0x07:
result += "RLCA"
case op == 0x08:
result += "EX AF, AF'"
case op == 0x09:
result += "ADD HL" + sep + "BC"
case op == 0x0A:
result += "LD A" + sep + "(BC)"
case op == 0x0B:
result += "DEC BC"
case op == 0x0C:
result += "INC C"
case op == 0x0D:
result += "DEC C"
case op == 0x0E:
result += "LD C" + sep + d.getB()
case op == 0x0F:
result += "RRCA"
// 10:1F
case op == 0x10:
// DJNZ rel
result += d.jr("DJNZ", "")
case op == 0x11:
result += "LD DE" + sep + d.getW()
case op == 0x12:
result += "LD (DE)" + sep + "A"
case op == 0x13:
result += "INC DE"
case op == 0x14:
result += "INC D"
case op == 0x15:
result += "DEC D"
case op == 0x16:
result += "LD D" + sep + d.getB()
case op == 0x17:
result += "RLA"
case op == 0x18:
result += d.jr("JR", "")
case op == 0x19:
result += "ADD HL" + sep + "DE"
case op == 0x1A:
result += "LD A" + sep + "(DE)"
case op == 0x1B:
result += "DEC DE"
case op == 0x1C:
result += "INC E"
case op == 0x1D:
result += "DEC E"
case op == 0x1E:
result += "LD E" + sep + d.getB()
case op == 0x1F:
result += "RRA"
// == 20:2F
case op == 0x20:
result += d.jr("JR", "NZ")
case op == 0x21:
result += "LD HL" + sep + d.getW()
case op == 0x22:
// LD (nn),HL
result += "LD (" + d.getW() + ")" + sep + "HL"
case op == 0x23:
result += "INC HL"
case op == 0x24:
result += "INC H"
case op == 0x25:
result += "DEC H"
case op == 0x26:
result += "LD H" + sep + d.getB()
case op == 0x27:
result += "DAA"
case op == 0x28:
result += d.jr("JR", "Z")
case op == 0x29:
result += "ADD HL" + sep + "HL"
case op == 0x2A:
result += "LD HL" + sep + "(" + d.getW() + ")"
case op == 0x2B:
result += "DEC HL"
case op == 0x2C:
result += "INC L"
case op == 0x2D:
result += "DEC L"
case op == 0x2E:
result += "LD L" + sep + d.getB()
case op == 0x2F:
result += "CPL"
// == 30:3F
case op == 0x30:
result += d.jr("JR", "NC")
case op == 0x31:
result += "LD SP" + sep + d.getW()
case op == 0x32:
result += "LD (" + d.getW() + ")" + sep + "A"
case op == 0x33:
result += "INC SP"
case op == 0x34:
result += "INC (HL)"
case op == 0x35:
result += "DEC (HL)"
case op == 0x36:
result += "LD (HL)" + sep + d.getB()
case op == 0x37:
result += "SCF"
case op == 0x38:
result += d.jr("JR", "C")
case op == 0x39:
result += "ADD HL" + sep + "SP"
case op == 0x3A:
result += "LD A" + sep + "(" + d.getW() + ")"
case op == 0x3B:
result += "DEC SP"
case op == 0x3C:
result += "INC A"
case op == 0x3D:
result += "DEC A"
case op == 0x3E:
result += "LD A" + sep + d.getB()
case op == 0x3F:
result += "CCF"
case op == 0x76:
result += "HALT"
case op >= 0x40 && op <= 0x7F:
// LD op8, op8
result += "LD " + operands[(op>>3)&0x07] + sep + operands[op&0x07]
case op >= 0x80 && op <= 0xBF:
// ALU op8
result += aluOp[(op>>3)&0x07] + operands[op&0x07]
case op == 0xc0:
result += "RET NZ"
case op == 0xc1:
result += "POP BC"
case op == 0xc2:
result += d.jp("JP", "NZ")
case op == 0xc3:
result += d.jp("JP", "")
case op == 0xc4:
result += d.jp("CALL", "NZ")
case op == 0xc5:
result += "PUSH BC"
case op == 0xc6:
result += "ADD A" + sep + d.getB()
case op == 0xc7 || op == 0xd7 || op == 0xe7 || op == 0xf7 || op == 0xcf || op == 0xdf || op == 0xef || op == 0xff:
// RST nnH
result += fmt.Sprintf("RST %d%dH", (op>>4)&3, (op&1)*8)
case op == 0xc8:
result += "RET Z"
case op == 0xc9:
result += "RET"
case op == 0xca:
result += d.jp("JP", "Z")
case op == 0xcb:
result += d.opocodeCB()
case op == 0xcc:
result += d.jp("CALL", "Z")
case op == 0xcd:
result += d.jp("CALL", "")
case op == 0xce:
result += "ADC A" + sep + d.getB()
case op == 0xd0:
result += "RET NC"
case op == 0xd1:
result += "POP DE"
case op == 0xd2:
result += d.jp("JP", "NC")
case op == 0xd3:
result += "OUT (" + d.getB() + ")" + sep + "A"
case op == 0xd4:
result += d.jp("CALL", "NC")
case op == 0xd5:
result += "PUSH DE"
case op == 0xd6:
result += "SUB " + d.getB()
case op == 0xd8:
result += "RET C"
case op == 0xd9:
result += "EXX"
case op == 0xda:
result += d.jp("JP", "C")
case op == 0xdb:
result += "IN A" + sep + " (" + d.getB() + ")"
case op == 0xdc:
result += d.jp("CALL", "C")
case op == 0xdd:
result += d.opocodeDD(op)
case op == 0xde:
result += "SBC A" + sep + d.getB()
case op == 0xe0:
result += "RET PO"
case op == 0xe1:
result += "POP HL"
case op == 0xe2:
result += d.jp("JP", "PO")
case op == 0xe3:
result += "EX (SP)" + sep + "HL"
case op == 0xe4:
result += d.jp("CALL", "PO")
case op == 0xe5:
result += "PUSH HL"
case op == 0xe6:
result += "AND " + d.getB()
case op == 0xe8:
result += "RET PE"
case op == 0xe9:
result += "JP (HL)"
case op == 0xea:
result += d.jp("JP", "PE")
case op == 0xeb:
result += "EX DE" + sep + "HL"
case op == 0xec:
result += d.jp("CALL", "PE")
case op == 0xed:
result += d.opocodeED()
case op == 0xee:
result += "XOR " + d.getB()
case op == 0xf0:
result += "RET P"
case op == 0xf1:
result += "POP AF"
case op == 0xf2:
result += d.jp("JP", "P")
case op == 0xf3:
result += "DI"
case op == 0xf4:
result += d.jp("CALL", "P")
case op == 0xf5:
result += "PUSH AF"
case op == 0xf6:
result += "OR " + d.getB()
case op == 0xf8:
result += "RET M"
case op == 0xf9:
result += "LD SP" + sep + "HL"
case op == 0xfa:
result += d.jp("JP", "M")
case op == 0xfb:
result += "EI"
case op == 0xfc:
result += d.jp("CALL", "M")
case op == 0xfd:
result += d.opocodeDD(op)
case op == 0xfe:
result += "CP " + d.getB()
default:
// All unknown as DB
result += fmt.Sprintf("DB 0x%02X", op)
}
return result
}
func (d *Disassembler) getW() string {
lo := d.core.MemRead(d.pc)
d.pc++
hi := d.core.MemRead(d.pc)
d.pc++
return fmt.Sprintf("0x%02X%02X", hi, lo)
}
func (d *Disassembler) getB() string {
lo := d.core.MemRead(d.pc)
d.pc++
return fmt.Sprintf("0x%02X", lo)
}
func (d *Disassembler) getRel() string {
offset := d.core.MemRead(d.pc)
var sign string
if int8(offset) < 0 {
sign = "-"
} else {
sign = "+"
}
return sign + fmt.Sprintf("0x%02X", offset&0x7F)
}
var shiftOps = []string{"RLC", "RRC", "RL", "RR", "SLA", "SRA", "SLL", "SRL"}
var bitOps = []string{"BIT", "RES", "SET"}
// opocodeCB disassemble Z80 Opcodes, with CB first byte
func (d *Disassembler) opocodeCB() string {
op := ""
opcode := d.getByte()
if opcode <= 0x3F {
op = shiftOps[opcode>>3&0x07] + operands[opcode&0x7]
} else {
op = shiftOps[(opcode>>6&0x03)-1] + operands[opcode&0x7]
}
return op
}
func (d *Disassembler) opocodeDD(op byte) string {
opcode := d.getByte()
result := ""
switch opcode {
case 0x09:
result = "ADD ii" + sep + "BC"
case 0x19:
result = "ADD ii" + sep + "DE"
case 0x21:
result = "LD ii" + sep + d.getW()
case 0x22:
result = "LD (" + d.getW() + ")" + sep + "ii"
case 0x23:
result = "INC ii"
case 0x24:
result = "INC IXH"
case 0x25:
result = "DEC IXH"
case 0x26:
result = "LD IXH" + sep + "n"
case 0x29:
result = "ADD ii" + sep + "ii"
case 0x2A:
result = "LD ii" + sep + "(" + d.getW() + ")"
case 0x2B:
result = "DEC ii"
case 0x34:
result = "INC (ii" + d.getRel() + ")"
case 0x35:
result = "DEC (ii" + d.getRel() + ")"
case 0x36:
result = "LD (ii" + d.getRel() + ")" + sep + "n"
case 0x39:
result = "ADD ii" + sep + "SP"
case 0x46:
result = "LD B" + sep + "(ii" + d.getRel() + ")"
case 0x4E:
result = "LD C" + sep + "(ii" + d.getRel() + ")"
case 0x56:
result = "LD D" + sep + "(ii" + d.getRel() + ")"
case 0x5E:
result = "LD E" + sep + "(ii" + d.getRel() + ")"
case 0x66:
result = "LD H" + sep + "(ii" + d.getRel() + ")"
case 0x6E:
result = "LD L" + sep + "(ii" + d.getRel() + ")"
case 0x70:
result = "LD (ii" + d.getRel() + ")" + sep + "B"
case 0x71:
result = "LD (ii" + d.getRel() + ")" + sep + "C"
case 0x72:
result = "LD (ii" + d.getRel() + ")" + sep + "D"
case 0x73:
result = "LD (ii" + d.getRel() + ")" + sep + "E"
case 0x74:
result = "LD (ii" + d.getRel() + ")" + sep + "H"
case 0x75:
result = "LD (ii" + d.getRel() + ")" + sep + "L"
case 0x77:
result = "LD (ii" + d.getRel() + ")" + sep + "A"
case 0x7E:
result = "LD A" + sep + "(ii" + d.getRel() + ")"
case 0x86:
result = "ADD A" + sep + "(ii" + d.getRel() + ")"
case 0x8E:
result = "ADC A" + sep + "(ii" + d.getRel() + ")"
case 0x96:
result = "SUB (ii" + d.getRel() + ")"
case 0x9E:
result = "SBC A" + sep + "(ii" + d.getRel() + ")"
case 0xA6:
result = "AND (ii" + d.getRel() + ")"
case 0xAE:
result = "XOR (ii" + d.getRel() + ")"
case 0xB6:
result = "OR (ii" + d.getRel() + ")"
case 0xBE:
result = "CP (ii" + d.getRel() + ")"
case 0xCB:
result = d.opocodeDDCB(op, opcode)
case 0xE1:
result = "POP ii"
case 0xE3:
result = "EX (SP)" + sep + "ii"
case 0xE5:
result = "PUSH ii"
case 0xE9:
result = "JP (ii)"
case 0xF9:
result = "LD SP" + sep + "ii"
default:
return fmt.Sprintf("DB 0x%02X, 0x%02X", op, opcode)
}
reg := "IX"
if op == 0xFD {
reg = "IY"
}
return strings.ReplaceAll(result, "ii", reg)
}
func (d *Disassembler) opocodeDDCB(op1 byte, op2 byte) string {
opcode := d.getByte()
result := ""
switch opcode {
case 0x06:
result = "RLC (ii" + d.getRel() + ")"
case 0x0E:
result = "RRC (ii" + d.getRel() + ")"
case 0x16:
result = "RL (ii" + d.getRel() + ")"
case 0x1E:
result = "RR (ii" + d.getRel() + ")"
case 0x26:
result = "SLA (ii" + d.getRel() + ")"
case 0x2E:
result = "SRA (ii" + d.getRel() + ")"
case 0x3E:
result = "SRL (ii" + d.getRel() + ")"
case 0x46:
result = "BIT 0" + sep + "(ii" + d.getRel() + ")"
case 0x4E:
result = "BIT 1" + sep + "(ii" + d.getRel() + ")"
case 0x56:
result = "BIT 2" + sep + "(ii" + d.getRel() + ")"
case 0x5E:
result = "BIT 3" + sep + "(ii" + d.getRel() + ")"
case 0x66:
result = "BIT 4" + sep + "(ii" + d.getRel() + ")"
case 0x6E:
result = "BIT 5" + sep + "(ii" + d.getRel() + ")"
case 0x76:
result = "BIT 6" + sep + "(ii" + d.getRel() + ")"
case 0x7E:
result = "BIT 7" + sep + "(ii" + d.getRel() + ")"
case 0x86:
result = "RES 0" + sep + "(ii" + d.getRel() + ")"
case 0x8E:
result = "RES 1" + sep + "(ii" + d.getRel() + ")"
case 0x96:
result = "RES 2" + sep + "(ii" + d.getRel() + ")"
case 0x9E:
result = "RES 3" + sep + "(ii" + d.getRel() + ")"
case 0xA6:
result = "RES 4" + sep + "(ii" + d.getRel() + ")"
case 0xAE:
result = "RES 5" + sep + "(ii" + d.getRel() + ")"
case 0xB6:
result = "RES 6" + sep + "(ii" + d.getRel() + ")"
case 0xBE:
result = "RES 7" + sep + "(ii" + d.getRel() + ")"
case 0xC6:
result = "SET 0" + sep + "(ii" + d.getRel() + ")"
case 0xCE:
result = "SET 1" + sep + "(ii" + d.getRel() + ")"
case 0xD6:
result = "SET 2" + sep + "(ii" + d.getRel() + ")"
case 0xDE:
result = "SET 3" + sep + "(ii" + d.getRel() + ")"
case 0xE6:
result = "SET 4" + sep + "(ii" + d.getRel() + ")"
case 0xEE:
result = "SET 5" + sep + "(ii" + d.getRel() + ")"
case 0xF6:
result = "SET 6" + sep + "(ii" + d.getRel() + ")"
case 0xFE:
result = "SET 7" + sep + "(ii" + d.getRel() + ")"
default:
result = fmt.Sprintf("DB 0x%02X, 0x%02X, 0x%02X", op1, op2, opcode)
}
return result
}
func (d *Disassembler) opocodeED() string {
opcode := d.getByte()
result := ""
switch opcode {
case 0x40:
result = "IN B" + sep + "(C)"
case 0x41:
result = "OUT (C)" + sep + "B"
case 0x42:
result = "SBC HL" + sep + "BC"
case 0x43:
result = "LD (" + d.getW() + ")" + sep + "BC"
case 0x44, 0x4C, 0x54, 0x5C, 0x64, 0x6C, 0x74, 0x7C:
result = "NEG"
case 0x45, 0x55, 0x5D, 0x65, 0x6D, 0x75, 0x7D:
result = "RETN"
case 0x46, 0x4E, 0x66, 0x6E:
result = "IM 0"
case 0x47:
result = "LD I" + sep + "A"
case 0x48:
result = "IN C" + sep + "(C)"
case 0x49:
result = "OUT (C)" + sep + "C"
case 0x4A:
result = "ADC HL" + sep + "BC"
case 0x4B:
result = "LD BC" + sep + "(" + d.getW() + ")"
case 0x4D:
result = "REТI"
case 0x4F:
result = "LD R" + sep + "A"
case 0x50:
result = "IN D" + sep + "(C)"
case 0x51:
result = "OUT (C)" + sep + "D"
case 0x52:
result = "SBC HL" + sep + "DE"
case 0x53:
result = "LD (nn)" + sep + "DE"
case 0x56, 0x76:
result = "IM 1"
case 0x57:
result = "LD A" + sep + "I"
case 0x58:
result = "IN E" + sep + "(C)"
case 0x59:
result = "OUT (C)" + sep + "E"
case 0x5A:
result = "ADC HL" + sep + "DE"
case 0x5B:
result = "LD DE" + sep + "(" + d.getW() + ")"
case 0x5E, 0x7E:
result = "IM 2"
case 0x5F:
result = "LD A" + sep + "R"
case 0x60:
result = "IN H" + sep + "(C)"
case 0x61:
result = "OUT (C)" + sep + "H"
case 0x62:
result = "SBC HL" + sep + "HL"
case 0x63:
result = "LD (nn)" + sep + "HL"
case 0x67:
result = "RRD"
case 0x68:
result = "IN L" + sep + " (C)"
case 0x69:
result = "OUT (C)" + sep + "L"
case 0x6A:
result = "ADC HL" + sep + " HL"
case 0x6B:
result = "LD HL" + sep + " (nn)"
case 0x6F:
result = "RLD"
case 0x70:
result = "INF"
case 0x71:
result = "OUT (C)" + sep + " 0"
case 0x72:
result = "SBC HL" + sep + "SP"
case 0x73:
result = "LD (nn)" + sep + "SP"
case 0x78:
result = "IN A" + sep + "(C)"
case 0x79:
result = "OUT (C)" + sep + "A"
case 0x7A:
result = "ADC HL" + sep + "SP"
case 0x7B:
result = "LD SP" + sep + "(" + d.getW() + ")"
case 0xA0:
result = "LDI"
case 0xA1:
result = "CPI"
case 0xA2:
result = "INI"
case 0xA3:
result = "OUTI"
case 0xA8:
result = "LDD"
case 0xA9:
result = "CPD"
case 0xAA:
result = "IND"
case 0xAB:
result = "OUTD"
case 0xB0:
result = "LDIR"
case 0xB1:
result = "CPIR"
case 0xB2:
result = "INIR"
case 0xB3:
result = "OTIR"
case 0xB8:
result = "LDDR"
case 0xB9:
result = "CPDR"
case 0xBA:
result = "INDR"
case 0xBB:
result = "OTDR"
default:
result = fmt.Sprintf("DB 0xED, 0x%02X", opcode)
}
return result
}

91
z80/dis/z80disasm_test.go Normal file
View File

@ -0,0 +1,91 @@
package dis
import "testing"
var disasm *Disassembler
type TestComp struct {
memory [65536]byte
}
func (t *TestComp) M1MemRead(addr uint16) byte {
return t.memory[addr]
}
func (t *TestComp) MemRead(addr uint16) byte {
return t.memory[addr]
}
func (t *TestComp) MemWrite(addr uint16, val byte) {
t.memory[addr] = val
}
func (t *TestComp) IORead(port uint16) byte {
return byte(port >> 8)
}
func (t *TestComp) IOWrite(port uint16, val byte) {
//
}
var testComp *TestComp
func init() {
testComp = &TestComp{}
for i := 0; i < 65536; i++ {
testComp.memory[i] = 0x3f
}
disasm = NewDisassembler(testComp)
}
func setMemory(addr uint16, value []byte) {
for i := 0; i < len(value); i++ {
testComp.memory[addr+uint16(i)] = value[i]
}
}
var test = []byte{0x31, 0x2c, 0x05, 0x11, 0x0e, 0x01, 0x0e, 0x09, 0xcd, 0x05, 0x00, 0xc3, 0x00, 0x00}
func Test_LD_SP_nn(t *testing.T) {
expected := " 0100 LD SP, 0x052C"
setMemory(0x100, test)
res := disasm.Disassm(0x100)
if res != expected {
t.Errorf("Error disasm LD SP, nn, result '%s', expected '%s'", res, expected)
}
}
func Test_LD_DE_nn(t *testing.T) {
expected := " 0103 LD DE, 0x010E"
setMemory(0x100, test)
res := disasm.Disassm(0x103)
if res != expected {
t.Errorf("Error disasm LD DE, nn, result '%s', expected '%s'", res, expected)
}
}
func Test_LD_C_n(t *testing.T) {
expected := " 0106 LD C, 0x09"
setMemory(0x100, test)
res := disasm.Disassm(0x106)
if res != expected {
t.Errorf("Error disasm LD C, n, result '%s', expected '%s'", res, expected)
}
}
func Test_CALL_nn(t *testing.T) {
expected := " 0108 CALL 0x0005"
setMemory(0x100, test)
res := disasm.Disassm(0x108)
if res != expected {
t.Errorf("Error disasm CALL nn, result '%s', expected '%s'", res, expected)
}
}
func Test_JP_nn(t *testing.T) {
expected := " 010B JP 0x0000"
setMemory(0x100, test)
res := disasm.Disassm(0x10b)
if res != expected {
t.Errorf("Error disasm JP nn, result '%s', expected '%s'", res, expected)
}
}