PyRTL floating point library #475
Codecov Report

```
@@            Coverage Diff             @@
##           development    #475   +/- ##
=============================================
+ Coverage        91.0%    91.9%    +0.9%
=============================================
  Files              25       31       +6
  Lines            7091     7452     +361
=============================================
+ Hits             6450     6845     +395
+ Misses            641      607      -34
```
fdxmw left a comment

Thank you for this contribution! Here are some initial comments.
pyrtl/rtllib/pyrtlfloat/_add_sub.py
Outdated
```python
from ._types import PyrtlFloatConfig, RoundingMode


class AddSubHelper:
```
Generally, classes with only @staticmethods probably shouldn't be classes :) It looks like this class can be removed, and all the static methods can be ordinary functions? Same comment for FloatUtils and MultiplicationHelper.
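As a sketch of what that refactor looks like (toy names and bodies, not the PR's actual signatures):

```python
# Before: a class used only as a namespace for @staticmethods.
#
#     class AddSubHelper:
#         @staticmethod
#         def add(a, b): ...
#
# After: ordinary module-level functions that callers import directly.
def add(a: int, b: int) -> int:
    """Toy stand-in for AddSubHelper.add as a plain function."""
    return a + b


def sub(a: int, b: int) -> int:
    """Toy stand-in for a subtraction helper as a plain function."""
    return a - b
```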
pyrtl/rtllib/pyrtlfloat/_add_sub.py
Outdated
```python
    raw_result_grs: pyrtl.WireVector,
) -> tuple[pyrtl.WireVector, pyrtl.WireVector]:
    last = raw_result_mantissa[0]
    guard = raw_result_grs[2]
```
Consider using wire_struct to simplify these kinds of concat/slice patterns. With a wire_struct, you wouldn't need this line, and later lines would instead refer to raw_result.guard.
Also see the wire_struct example.
```python
    fp_prop: FPTypeProperties, wire: pyrtl.WireVector
) -> pyrtl.WireVector:
    return wire[
        fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits
```
Similarly, wire_struct can remove the need for these tricky bit offset calculations
I was trying to do this, but the problem is that the number of mantissa and exponent bits differs based on fp_prop, so we would need to set the bitwidths of the slices in the wire_struct dynamically at runtime. Is it possible to do this with wire_struct?
It is possible with some trickery, but it's probably overkill in this case, since you only need to select between a few well-known bit layouts (the user can't define an arbitrary custom floating point format). I think you could define a set of wire_structs that share the same interface, something like:
```python
@pyrtl.wire_struct
class Float16:
    sign: 1
    exponent: 5
    fraction: 10


@pyrtl.wire_struct
class Float32:
    sign: 1
    exponent: 8
    fraction: 23
```

So the idea is that these objects share the same interface, so it's easier to write code that works with any of these types. For example, if you want to compare sign bits, you'd write a.sign == b.sign, and that would work regardless of whether a or b are Float16 or Float32. If you care about types (I encourage you to! :), these are Unions like Float16 | Float32.
In case it's useful, here's the trickery: you can define a function that returns a dynamically-defined class:

```python
import pyrtl

def define_my_struct(a_bits: int, b_bits: int):
    @pyrtl.wire_struct
    class MyInternalClass:
        a: a_bits
        b: b_bits
    return MyInternalClass

MyStruct = define_my_struct(a_bits=4, b_bits=8)
my_struct_instance = MyStruct(a=1, b=2)
print("a bitwidth", my_struct_instance.a.bitwidth)
print("b bitwidth", my_struct_instance.b.bitwidth)
print("total bitwidth", my_struct_instance.bitwidth)
```

```
$ uv run ...
a bitwidth 4
b bitwidth 8
total bitwidth 12
```
pyrtl/rtllib/pyrtlfloat/_add_sub.py
Outdated
```python
class AddSubHelper:
    @staticmethod
    def add(
```
We'll need docstrings for all user-facing classes, methods, and functions. Docstrings for internal stuff would also be great; for example, as someone new to this code, it would really help to have an idea of what make_denormals_zero does conceptually.
```python
)


@staticmethod
def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector:
```
It would really help to have some comments with brief definitions and references for important floating point concepts (NaN, inf, denormalized (how are these different?), guard, round, sticky, ...).
I'm not expecting you to teach me everything in comments, but it would really help to know what I should read if I want to understand how this works :)
```python
exponent_min_value = sum_exponent_min_value - need_to_normalize


if rounding_mode == RoundingMode.RNE:
    raw_result_exponent = rounded_product_exponent[0:num_exp_bits]
```
I think you can omit the 0 here, also on line 119 below. Doing so would be more consistent with your code above, line 63 for example.
```python
result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits)
result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)


operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz)
```
Capitalization of NaN is inconsistent here, it's nan on the left hand side and NaN on the right hand side. I'd rename is_NaN to is_nan since we pretty consistently use lowercase with underscores for method and function names.
```python
self.sim = pyrtl.Simulation()


def test_multiplication_simple(self):
    self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000})
```
We'll need a better story for how this is tested :) A few things to consider:

- Every PyRTL developer will need to run these tests for the indefinite future, so we really don't want the tests to run for too long. I'd say all continuously-run floating point tests should complete in a couple seconds or less. These tests can't be comprehensive, but we should still try to get as much value as we can out of them.
- Try to think about the most interesting inputs (zeroes, inf, nan, ...), and which combinations of inputs and operations are worth testing. I would avoid randomly generated tests because they are unlikely to cover all the interesting combinations with a reasonable number of test cases.
- Think about how someone besides yourself would debug one of these test failures. It would help a lot to know what cases a particular test exercises (# Test addition with inf, etc). It also helps to know if the unexpected output is slightly wrong, or complete nonsense, which is pretty hard to tell from a raw bit pattern (see next point :)
- Consider adding some helper functions that can convert float to and from these bit patterns. That would make it easier for someone to understand that the test checks that 1.0 + 2.0 == 3.0, for example.
Yep, testing needs more work; I'll add more rigorous tests and cover edge cases. That's why I marked this as a draft PR: it's not finished yet.
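A sketch of the suggested conversion helpers, assuming CPython's struct module and its IEEE 754 binary16 format code "e" (the function names are made up):

```python
import struct


def float_to_float16_bits(value: float) -> int:
    """Pack a Python float into its IEEE 754 binary16 bit pattern."""
    return int.from_bytes(struct.pack("<e", value), "little")


def float16_bits_to_float(bits: int) -> float:
    """Unpack a binary16 bit pattern back into a Python float."""
    return struct.unpack("<e", bits.to_bytes(2, "little"))[0]


# "1.0 + 2.0 == 3.0" becomes readable instead of raw bit patterns:
assert float_to_float16_bits(1.0) == 0x3C00
assert float_to_float16_bits(2.0) == 0x4000
assert float16_bits_to_float(0x4200) == 3.0
```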
```python
@staticmethod
def multiply(

def mul(
```
This doesn't seem very useful, since it just passes all its arguments through to MultiplicationHelper.multiply? Seems better to point users directly to MultiplicationHelper.multiply. Same comments for the other methods below.
The reason I have this FloatOperations class and separate MultiplicationHelper and AddSubHelper classes is that the user is supposed to use FloatOperations; MultiplicationHelper and AddSubHelper are internal helpers the user is not supposed to touch. I created these helpers so I can separate the multiplication and addition/subtraction logic into separate files.
I think that makes sense, but I'd still try to refactor, it would be nice to remove these functions that are just pass-through wrappers for another function. Maybe:

- `AddSubHelper.add` could be an ordinary function instead of a `@staticmethod`
- `_add_sub.py` could be renamed to `add_sub.py`
- Non-public functions in `add_sub.py` could have an underscore prefix
- `__init__.py` could import `add` from `add_sub`

I think that would still make it clear which parts of the interface are meant for public use, while removing a layer of indirection?

This kind of indirection tends to be annoying when debugging problems, because you have to jump through another hoop to find the code you're looking for. "Our princess is in another castle!"
pyrtl/rtllib/pyrtlfloat/_add_sub.py
Outdated
```python
    operand_a: pyrtl.WireVector,
    operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
    fp_type_props = config.fp_type_properties
```
All user-facing methods/functions should check (somewhere, not necessarily here) that the operand bitwidths are consistent with config.
fdxmw left a comment
Sorry for the slow review! This code is making a lot more sense to me now, and I think it looks pretty good! I've only reviewed the main addition and subtraction code for now, I'll take a look at multiplication tomorrow. Here are some comments in the meantime.
Also, I still stand by my earlier comments about trying to remove layers of indirection that don't do much work (FloatOperations, _BaseTypedFloatOperations) and trying to use wire_struct instead of calculating bit offsets :)
```python
# RNE rounding uses the guard, round, and sticky bits.
# When shifting the smaller mantissa to the right, some bits are shifted out.
# The first bit shifted out becomes the guard bit, the second becomes the round bit,
```
The words first and second here are a bit confusing because I can interpret it the other way, for example if I shift the bits one at a time, the first bit that gets shifted out is the least significant bit (smaller_operand_mantissa[0]), regardless of the shift_amount.
Maybe the first bit could be something like "the most significant bit to the right of the mantissa, after shifting"?
```python
# shifted out are the guard and round bits, and the sticky bit is
# the OR of all remaining bits.
with smaller_mantissa_shift_amount >= 2:
    guard_and_round = pyrtl.shift_right_logical(
```
This is fine for now, but keep in mind that shift_right_logical creates a barrel shifter, which is pretty expensive, and we just made one on L75 above (and we'll make another on L95 below :). It should be possible to use one barrel shifter instead of three, by appending zero-valued GRS bits to the smaller_operand_mantissa before we shift it. That should work for the guard and round bits, but properly tracking the sticky bit would likely require changes to the barrel shifter.
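A plain-integer sketch of that idea (this is a software model, not PyRTL, and the helper name is made up): appending three zero bits before the shift lets a single shifter produce the guard and round bits, while the sticky bit still needs its own reduction over the shifted-out tail, which is the part that would require barrel shifter changes.

```python
def shift_with_grs(mantissa: int, shift: int):
    """Right-shift once, recovering guard/round from the widened value."""
    shifted = (mantissa << 3) >> shift  # one shift, with G/R/S positions appended
    kept = shifted >> 3                 # the aligned mantissa
    guard = (shifted >> 2) & 1          # most significant bit shifted out
    rnd = (shifted >> 1) & 1            # next bit shifted out
    # Sticky = OR of all bits shifted out past the round position.
    tail_mask = (1 << max(shift - 2, 0)) - 1
    sticky = int((mantissa & tail_mask) != 0)
    return kept, guard, rnd, sticky
```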
```python
    pyrtl.Const(1), get_mantissa(fp_type_props, operand_larger)
)


# Align mantissas by shifting the smaller one to match the larger's exponent.
```
I think an important insight here is:

1. Shifting the mantissa right by one divides the value by two.
2. Adding one to the exponent multiplies the value by two.
3. If we simultaneously do (1) and (2), we don't change the value, but it helps match the operands' exponents for easier math.
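That invariant is easy to check with a toy integer model (not the PR's code), where a (mantissa, exponent) pair represents mantissa * 2**exponent:

```python
def represented_value(mantissa: int, exponent: int) -> int:
    """Toy model: (mantissa, exponent) represents mantissa * 2**exponent."""
    return mantissa * 2**exponent


# Shift the mantissa right by one while adding one to the exponent:
# the represented value is unchanged (as long as no set bit falls off).
m, e = 0b1100, 3
assert represented_value(m, e) == represented_value(m >> 1, e + 1) == 96
```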
```python
)


# Perform subtraction of operands.
sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = _sub_operands(
```
Naming seems a little inconsistent here, above _add_operands returns sum_exponent, while here _sub_operands returns sub_exponent. Seems like this should be called difference_exponent (addition results in a sum, subtraction results in a difference)?
```python
    larger_operand_exponent > larger_exponent_max_value
):
    final_result_sign |= larger_operand_sign
    if rounding_mode == RoundingMode.RNE:
```
It's pretty counterintuitive that rounding_mode determines whether you get inf or largest_finite_number on overflow. A reference would be really useful here.
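For what it's worth, this matches IEEE 754-2019 §7.4: overflow under round-to-nearest delivers infinity, while round-toward-zero delivers the format's largest finite number. A toy binary16 model of that rule (the field constants and function name are mine, not the PR's):

```python
# binary16 field widths: an all-ones exponent encodes inf/NaN.
EXP_ALL_ONES = 0x1F
MANT_ALL_ONES = 0x3FF


def overflow_fields(rounding_mode: str) -> tuple[int, int]:
    """Return the (exponent, mantissa) fields produced on overflow."""
    if rounding_mode == "RNE":
        return EXP_ALL_ONES, 0  # round-to-nearest overflows to infinity
    if rounding_mode == "RTZ":
        return EXP_ALL_ONES - 1, MANT_ALL_ONES  # largest finite magnitude
    raise ValueError(f"unknown rounding mode: {rounding_mode}")
```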
```python
# subtract the rounding increment from the absolute minimum exponent,
# which is one greater than the all-0s exponent (reserved for
# zero and denormals).
initial_larger_exponent_min_value = pyrtl.Const(1)
```
Seems like this initialization should be = None, since it always gets overwritten in the if statement below? (calling pyrtl.Const(1) creates a Const and adds it to the Block, but that Const is never used)
```python
# Round the result if using RNE rounding mode.
if rounding_mode == RoundingMode.RNE:
    (
        raw_result_rounded_exponent,
```
It doesn't seem useful to give the rounded results new names (raw_result_rounded_exponent, etc), since we don't seem to use the not-rounded results (raw_result_exponent, etc). If this instead updated raw_result_exponent and friends, I think we could remove the conditional on L287 below, since we can always assign final_result_exponent |= raw_result_exponent ?
Also this function is quite long and there are a lot of variables, so it would really help readers to clarify variable lifetimes (which variables are still in use), and how groups of variables are related. Some ways to do this include:

- Avoid creating new variables when it makes sense to reuse an existing one (this suggestion)
- Try to define fewer variables with more internal structure, rather than defining a flat layer of top-level variables (`raw_result.exponent` rather than `raw_result_exponent`) (an earlier suggestion)
- Factor out more functions, which clearly defines the inputs and outputs for each chunk of code, and creates more local variable scopes
- Explicitly `del` variables that aren't needed anymore
```python
):
    final_result_sign |= larger_operand_sign
    make_zero(final_result_exponent, final_result_mantissa)
with pyrtl.otherwise:
```
Could use a comment like "Otherwise no special cases apply: this is the common case" :)
This should be called _float_utils.py? (only one L in utilities :)
```python
    fp_prop: FPTypeProperties, wire: pyrtl.WireVector
) -> pyrtl.WireVector:
    """
    Returns whether the floating point number is denormalized.
```
A reference that defines denormalized numbers would be very helpful. https://en.wikipedia.org/wiki/Subnormal_number seems good, but maybe you've found something better?
```python
return out


def make_inf(
```
It's a little confusing that make_denormals_zero returns a WireVector while make_inf conditionally assigns its argument WireVectors. Consistency is clearer, so try to be consistent :)
I'd recommend just returning WireVectors in all these methods because the implementation is much easier to understand. Very few people know what happens when you do a |= assignment outside of a conditional_assignment block.
```python
# Denormalized numbers are not supported, so we flush them to zero.
operands = (operand_a, operand_b)
operands_daz = tuple(make_denormals_zero(fp_type_props, op) for op in operands)
```
I left a similar comment on _add_sub.py, but I'd redefine operands here rather than defining a new operands_daz since we don't seem to use operands anymore after this point.
Similarly, this is a big function, so think about how someone new to this code will read it, and look for ways to help them navigate and understand what's important.
Don't forget that you have the Curse of Knowledge: you wrote all this code, so it already makes perfect sense to you :) To really test your code's readability, you have to show it to other people.
```python
total_bits = num_exp_bits + num_mant_bits + 1


# Denormalized numbers are not supported, so we flush them to zero.
operand_a_daz = make_denormals_zero(fp_type_props, operand_a)
```
Consider grouping the operands into a tuple, for consistency with the multiplication implementation, and removing redundant code. Maybe it should be a sorted tuple, so the smaller operand is operands[0] and we wouldn't need more names to track smaller/larger?
```python
def decode_float16(bits):
    """Decode Float16 bits to sign, exponent, and mantissa."""
    return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF
```
Similarly, assert that bits fits in 16 bits
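A sketch of the decoder with that range check added (extending the decode_float16 shown in the diff above):

```python
def decode_float16(bits: int):
    """Decode Float16 bits to sign, exponent, and mantissa."""
    assert 0 <= bits < (1 << 16), f"expected a 16-bit pattern, got {bits:#x}"
    return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF


assert decode_float16(0x3C00) == (0, 15, 0)  # 1.0
assert decode_float16(0xFC00) == (1, 31, 0)  # -inf
```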
```python
"""Tests for rounding in addition (RNE and RTZ)."""


def setUp(self):
    pyrtl.reset_working_block()
```
All these tests seem to use the same hardware, so it would help to factor this logic out to make it obvious that we're still using the same hardware, or to clearly point out any differences. Trying to determine this as a reader is not easy :) Some ideas:

- Just define all the tests that use the same hardware in the same class. You can add big section comments like

  ```python
  ############################
  # Additional rounding tests.
  ############################
  ```

  to group the different kinds of tests for people reading the code. I recommend this approach.
- Factor the hardware definition out into a function.
- Move the hardware definition to a base class.
```python
self.assertTrue(is_nan(self.sim.inspect("result_rtz")))


def test_add_denormal_flushed_to_zero(self):
    """Test that denormal inputs are flushed to zero."""
```
Same nit about consistency, I've seen at least denormal, denorms, and denormalized in this PR. Try to pick one and use it consistently :)
```python
def test_add_one_plus_two(self):
    """Test 1.0 + 2.0 = 3.0"""
    self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_TWO})
    self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_THREE)
```
I'm sure you've run into this, but think about how someone would debug one of these assertEqual failures. I think you'd get a message that says "{big decimal number} is not equal to {other big decimal number}", which is not very helpful.
Most likely, the first thing you'd do to debug a failure is to call decode_float16 to split out the sign, exponent, and mantissa on the unexpected value. Since this is the most likely first debug step, it'd really help the next person that works on this (which could be you! :) to integrate decode_float16 into these checks.
Maybe we can define a assertFloat16Equal helper that internally calls decode_float16 and separately calls assertEqual on the sign, exponent, and mantissa?
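A sketch of what that helper could look like (the mixin class name is made up, and decode_float16 is inlined here so the example is self-contained):

```python
import unittest


def decode_float16(bits):
    """Decode Float16 bits to sign, exponent, and mantissa."""
    return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF


class Float16Assertions(unittest.TestCase):
    """Mixin providing a field-by-field float16 equality check."""

    def assertFloat16Equal(self, actual, expected):
        fields = ("sign", "exponent", "mantissa")
        for name, a, e in zip(fields, decode_float16(actual), decode_float16(expected)):
            self.assertEqual(
                a, e, f"{name} differs: {actual:#06x} vs {expected:#06x}"
            )
```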
```python
FLOAT16_NEG_ZERO = 0x8000
FLOAT16_POS_INF = 0x7C00
FLOAT16_NEG_INF = 0xFC00
FLOAT16_NAN = 0x7E00  # Quiet NaN with mantissa bit set
```