gigatron/rom/Contrib/psr/multiply/src/quarter-square.asm.py
2025-01-28 19:17:01 +03:00

382 lines
10 KiB
Python

import math
import sys
from asm import (
AC,
C,
X,
Y,
adda,
align,
anda,
bge,
blt,
bne,
bra,
disableListing,
enableListing,
end,
fillers,
hi,
jmp,
label,
ld,
lo,
nop,
ora,
pc,
st,
suba,
writeRomFiles,
xora,
zpByte,
)
if __name__ == "__main__":
enableListing()
# Inputs
a, b = zpByte(), zpByte()
# Output
result = zpByte(2)
# Used for storage
tmp = zpByte()
# Used for the continuation address after lookup.
continuation = zpByte()
# Used for the action to take with the high byte.
high_byte_action = zpByte()
# The following code implements a lookup table of floored quarter squares,
# for values up to 255.
# This is supposed to enable a fast multiplication for 7-bit numbers.
# First the high-bytes.
# The table is shifted down by 32 places, as the first 32 high-bytes are all zero
# This allows us to have code later in the page which we can branch back to.
align(0x100, size=0x100)
label("Quarter-squares lookup table")
for i in range(32, 256):
val = math.floor(i**2 / 4)
ld(hi(val))
C(f"${val:04x} = {val} = floor({i} ** 2 / 4); ${val:04x} >> 8 = ${val >> 8:02x}")
# We jump back here after looking up the low-byte of the result.
label("low-byte return point")
ld(hi("multiply 7x7"), Y)
jmp(Y, [continuation])
ld(hi(pc()), Y) # Make it easy to get back here!
cost_of_low_byte_return = 3
label("table entry.possibly-negative")
# AC is negative, if b > a. Find absolute value
blt(pc() + 3) # 1
bra(pc() + 3) # 2
suba(1) # 3; if >= 0
xora(0xFF) # 3; if < 0
adda(1) # 4
cost_of_absolute = 4
label("table entry")
# Calculate an index into the high-byte table.
# This is basically a matter of subtracting 32, and jumping in if the result >= 0.
# But values greater than 160 have the sign-bit set after subtraction,
# despite being >32.
# We test for the sign bit and jump after subtraction even if 'negative' in these cases.
st([tmp]) # 1
blt(pc() + 5) # 2
suba(32) # 3
bge(AC) # 4
bra([high_byte_action]) # 5
ld(0) # 6
bra(AC) # 4
bra([high_byte_action]) # 5
cost_of_high_byte_table_entry = 6
# Some space here for other code?
fillers(until=251)
# We jump back here after looking up the high-byte of the result.
# Counting is in reverse
label("high-byte action.invert-and-add")
xora(0xFF) # 4
label("high-byte action.add")
adda([result + 1]) # 3
label("high-byte action.store")
st([result + 1]) # 2
ld([tmp]) # 1
assert pc() & 0xFF == 0xFF, pc()
cost_of_high_byte_invert_and_add = 4
cost_of_high_byte_add = 3
cost_of_high_byte_store = 2
label("low-byte table entry")
# With the index in the accumulator, and the current page in the Y register,
# We jump to the right entry in the low-byte table, replace AC with the result,
# and jump back out to 'low-byte return point' defined above,
# using the double-jump trick.
# This exploits the fact the following is the last instruction on the page,
# and the high byte of the PC has already incremented,
# so the bne will take us to an address in the next page.
# We use the Y register to return to the current page.
# The table has two zeros at the start, meaning that we can replace the first
# with the jump back.
bne(AC) # 1
align(0x100, size=0x100)
jmp(Y, "low-byte return point") # 2
ld(0) # 3
cost_of_low_byte_table_entry = 3
C("0 = floor(0 ** 2 / 4) and floor(1 ** 2 / 4)")
for i in range(2, 256):
val = math.floor(i**2 / 4)
ld(val)
C(f"${val:04x} = {val} = floor({i} ** 2 / 4)")
# Code copied from the main ROM. This provides a lookup table for right-shifts
align(0x100, size=0x100)
label("shiftTable")
shiftTable = pc()
for ix in range(255):
for n in range(1, 9): # Find first zero
if ~ix & (1 << (n - 1)):
break
pattern = ["x" if i < n else "1" if ix & (1 << i) else "0" for i in range(8)]
ld(ix >> n)
C("0b%s >> %d" % ("".join(reversed(pattern)), n))
assert pc() & 255 == 255
bra([continuation]) # Jumps back into next page
align(0x100, size=0x100)
nop() #
label("multiply 7x7")
# The formula is floor(((a + b) ** 2) / 4) - floor(((a - b) ** 2) / 4)
ld(".after-first-lookup") # 1
st([continuation])
ld(hi("Quarter-squares lookup table"), Y)
ld("high-byte action.store")
st([high_byte_action]) # 5
ld([a])
jmp(Y, "table entry") # 7
adda([b]) # 8
cost_to_first_lookup = 8
cost_after_first_lookup = (
cost_to_first_lookup
+ cost_of_high_byte_table_entry
+ cost_of_high_byte_store
+ cost_of_low_byte_table_entry
+ cost_of_low_byte_return
)
label(".after-first-lookup")
# On return we have the low-byte in the accumulator
# We can safely add one to it without causing an overflow,
# because 255 does not appear in the low-byte table.
# This is part of the following subtraction.
adda(1) # 1
st([result])
ld(".after-second-lookup")
st([continuation])
ld("high-byte action.invert-and-add") # 5
st([high_byte_action])
ld([a])
jmp(Y, "table entry.possibly-negative") # 8
suba([b]) # 9
cost_between_lookups = 9
cost_after_second_lookup = (
cost_after_first_lookup
+ cost_between_lookups
+ cost_of_absolute
+ cost_of_high_byte_table_entry
+ cost_of_high_byte_invert_and_add
+ cost_of_low_byte_table_entry
+ cost_of_low_byte_return
)
label(".after-second-lookup")
xora(0xFF) # 1
# We need to add this to the result
# But we may have a carry
adda([result])
st([tmp])
blt(pc() + 4) # 5
suba([result]) # 6
bra(pc() + 4) # 7
ora([result]) # 8
bra(pc() + 2) # 7
anda([result]) # 8
anda(0b1000_0000, X) # 9
ld([tmp])
st([result]) # 10
ld([X])
adda([result + 1])
st([result + 1]) # 13
cost_of_final_add = 13
cost_of_7bit_multiply = cost_after_second_lookup + cost_of_final_add
C(f"Total cost: {cost_of_7bit_multiply} cycles")
label("done")
nop()
label("multiply 8x8")
# Extend the 7bit x 7bit multiplication to 8bit x 8bit
#
# The logic goes as follows.
# Let A and B be the low seven bits of a and b, e.g.
# A = a₆2⁶ + a₅2⁵ + a₄2⁴ + a₃2³ + a₂2² + a₁2¹ + a₀2⁰
# B = b₆2⁶ + b₅2⁵ + b₄2⁴ + b₃2³ + b₂2² + b₁2¹ + b₀2⁰
# Then we could think of 8bit x 8bit multiply as
# (a₇2⁷ + A)(b₇2⁷ + B)
# Multiplying out the brackets gives
# a₇2⁷b₇2⁷ + a₇2⁷B + b₇2⁷A + AB
# (and AB is the result we already know how to calculate)
# Simplifying a bit, we get
# a₇b₇2¹⁴ + 2⁷(a₇B + b₇A) + AB
# Since a₇ and b₇ are one or zero, multiplying by them is like an if.
# We can consider four cases,
# if a₇ is 1 and b₇ is 1:
# 2¹⁴ + 2⁷(B + A) + AB
# if a₇ is 1 and b₇ is 0:
# 2⁷B + AB
# if a₇ is 0 and b₇ is 1:
# 2⁷A + AB
# if a₇ is 0 and b₇ is 0:
# AB
# Multiplication by 2⁷(=128) is a left-shift by seven, which
# can be broken down as moving the LSB of the low-byte to the
# MSB of the low-byte, and setting all of the other bits to zero,
# The other bits of the low byte can be right-shifted by one,
# and moved to the high byte.
# Test which of the branches we are on.
ld([a]) # 1
xora([b])
blt(".one MSB set") # 3
ld([a]) # 4
# Both MSBs equal
bge("multiply 7x7") # 5
anda(0b0111_1111) # 6
cost_of_8bit_multiply__both_msbs_low = cost_of_7bit_multiply + 6
# Both MSBs set
st([a]) # 7; a = A
ld(".after right-shift")
st([continuation])
ld(2**14 >> 8) # 10; Write the high-byte for later addition.
st([result + 1])
ld([b])
anda(0b0111_1111)
st([b]) # b = B
adda([a]) # 15
cost_of_both_msbs_set = 15
label(". << 7")
st([tmp]) # 1
anda(0b0000_0001) # Clear all but the bottom bit
adda(0b0111_1111) # Carries bottom bit to top bit
anda(0b1000_0000) # Clears all but the top bit
st([result]) # 5
ld([tmp])
anda(0b1111_1110) # Calculate index to right-shift-table
ld(hi("shiftTable"), Y)
jmp(Y, AC) # 9
bra(0xFF) # 10
# 11 ld (a + b) >> 1
# 12 bra [continuation]
# 12 nop
cost_of_right_shift = 13
cost_after_right_shift = cost_of_both_msbs_set + cost_of_right_shift
label(".one MSB set")
bge(".b has msb set") # 5
ld(0b0111_1111) # 6
# a has msb set
anda([a]) # 7
st([a])
ld(lo(".after right-shift") + 1)
st([continuation]) # 10
bra(". << 7") # 11
ld([b]) # 12
label(".b has msb set")
anda([b]) # 7
st([b])
ld(lo(".after right-shift") + 1)
st([continuation]) # 10
bra(". << 7") # 11
ld([a]) # 12
cost_of_one_msb_set = 12
one_msb_cost_saving = cost_of_both_msbs_set - cost_of_one_msb_set + 1
label(".after right-shift")
adda([result + 1]) # 1
st([result + 1])
ld(".after-first-lookup-8bit")
st([continuation])
ld(hi("Quarter-squares lookup table"), Y) # 5
ld("high-byte action.add")
st([high_byte_action])
ld([a])
jmp(Y, "table entry") # 9
adda([b]) # 10
cost_of_after_right_shift = 10
cost_after_first_lookup__8bit = (
cost_after_right_shift
+ cost_of_after_right_shift
+ cost_of_high_byte_table_entry
+ cost_of_high_byte_add
+ cost_of_low_byte_table_entry
+ cost_of_low_byte_return
)
label(".after-first-lookup-8bit")
# On return we have the low-byte in the accumulator
# We already either have 0 or 128 in the low byte of the result.
# Adding will cause an overflow into the high byte iff both
# have the MSB set, in which case the result will definitely
# not have the MSB set.
adda([result]) # 1
blt(pc() + 3) # 2
bra(pc() + 3) # 3
ld([result], X) # 4; May be a carry
ld(0, X) # 4; Definitely no carry
# We can safely add one to the result without causing a further overflow,
# because 255 does not appear in the low-byte table.
# This is part of the subtraction.
adda(1) # 5
st([result])
# Do the carry
ld([X])
adda([result + 1])
st([result + 1])
ld(".after-second-lookup") # 10
st([continuation])
ld("high-byte action.invert-and-add")
st([high_byte_action])
ld([a])
jmp(Y, "table entry.possibly-negative") # 15
suba([b]) # 16
cost_between_lookups__8bit = 16
cost_after_second_lookup__8bit = (
cost_after_first_lookup__8bit
+ cost_between_lookups__8bit
+ cost_of_absolute
+ cost_of_high_byte_table_entry
+ cost_of_high_byte_invert_and_add
+ cost_of_low_byte_table_entry
+ cost_of_low_byte_return
)
cost_of_8bit_multiply = cost_after_second_lookup__8bit + cost_of_final_add
no_msb_cost_saving = cost_of_8bit_multiply - cost_of_8bit_multiply__both_msbs_low
C(f"Worst case cost: {cost_of_8bit_multiply}")
C(f"Saving when both values < 128: {no_msb_cost_saving}")
C(f"Saving when only one value < 128: {one_msb_cost_saving}")
end()
if __name__ == "__main__":
disableListing()
writeRomFiles(sys.argv[0])