Skip to content

Pyrtl floating point library#475

Open
gaborszita wants to merge 3 commits intoUCSBarchlab:developmentfrom
gaborszita:pyrtlfloat
Open

Pyrtl floating point library#475
gaborszita wants to merge 3 commits intoUCSBarchlab:developmentfrom
gaborszita:pyrtlfloat

Conversation

@gaborszita
Copy link
Contributor

No description provided.

@codecov
Copy link

codecov bot commented Oct 30, 2025

Codecov Report

❌ Patch coverage is 98.89503% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.9%. Comparing base (8c706f6) to head (e601040).
⚠️ Report is 5 commits behind head on development.

Files with missing lines Patch % Lines
pyrtl/rtllib/pyrtlfloat/floatoperations.py 91.9% 4 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff              @@
##           development    #475     +/-   ##
=============================================
+ Coverage         91.0%   91.9%   +0.9%     
=============================================
  Files               25      31      +6     
  Lines             7091    7452    +361     
=============================================
+ Hits              6450    6845    +395     
+ Misses             641     607     -34     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@fdxmw fdxmw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this contribution! Here are some initial comments

from ._types import PyrtlFloatConfig, RoundingMode


class AddSubHelper:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, classes with only @staticmethods probably shouldn't be classes :) It looks like this class can be removed, and all the static methods can be ordinary functions? Same comment for FloatUtils and MultiplicationHelper

raw_result_grs: pyrtl.WireVector,
) -> tuple[pyrtl.WireVector, pyrtl.WireVector]:
last = raw_result_mantissa[0]
guard = raw_result_grs[2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using wire_struct to simplify these kinds of concat/slice patterns. With a wire_struct, you wouldn't need this line, and later lines would instead refer to raw_result.guard.

Also see the wire_struct example.

fp_prop: FPTypeProperties, wire: pyrtl.WireVector
) -> pyrtl.WireVector:
return wire[
fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, wire_struct can remove the need for these tricky bit offset calculations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to do this, but the problem is that the number of mantissa and exponent bits differs based on fp_prop, so we would need to set the bitwidths of the slices in the wire_struct dynamically at runtime. Is it possible to do this with wire_struct?

Copy link
Member

@fdxmw fdxmw Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is possible with some trickery, but it's probably overkill in this case, since you only need to select between a few well-known bit layouts (the user can't define an arbitrary custom floating point format). I think you could define a set of wire_structs that share the same interface, something like:

@pyrtl.wire_struct
class Float16
    sign: 1
    exponent: 5
    fraction: 10

@pyrtl.wire_struct
class Float32
    sign: 1
    exponent: 8
    fraction: 23

So the idea is that these objects share the same interface, so it's easier to write code that works with any of these types. For example, if you want to compare sign bits, you'd write a.sign == b.sign, and that would work regardless of whether a or b are Float16 or Float32. If you care about types (I encourage you to! :), these are Unions like Float16 | Float32.

In case it's useful, here's the trickery, you can define a function that returns a dynamically-defined class:

import pyrtl

def define_my_struct(a_bits: int, b_bits: int):
    @pyrtl.wire_struct
    class MyInternalClass:
        a: a_bits
        b: b_bits

    return MyInternalClass

MyStruct = define_my_struct(a_bits=4, b_bits=8)

my_struct_instance = MyStruct(a=1, b=2)

print("a bitwidth", my_struct_instance.a.bitwidth)
print("b bitwidth", my_struct_instance.b.bitwidth)
print("total bitwidth", my_struct_instance.bitwidth)
$ uv run ...
a bitwidth 4
b bitwidth 8
total bitwidth 12


class AddSubHelper:
@staticmethod
def add(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need docstrings for all user-facing classes, methods, and functions. Docstrings for internal stuff would also be great; for example as someone new to this code it would really help to have an idea of what make_denormals_zero does conceptually.

)

@staticmethod
def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would really help to have some comments with brief definitions and references for important floating point concepts (NaN, inf, denormalized (how are these different?), guard, round, sticky, ...).

I'm not expecting you to teach me everything in comments, but it would really help to know what I should read if I want to understand how this works :)

exponent_min_value = sum_exponent_min_value - need_to_normalize

if rounding_mode == RoundingMode.RNE:
raw_result_exponent = rounded_product_exponent[0:num_exp_bits]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can omit the 0 here, also on line 119 below. Doing so would be more consistent with your code above, line 63 for example.

result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits)
result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)

operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capitalization of NaN is inconsistent here, it's nan on the left hand side and NaN on the right hand side. I'd rename is_NaN to is_nan since we pretty consistently use lowercase with underscores for method and function names.

self.sim = pyrtl.Simulation()

def test_multiplication_simple(self):
self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need a better story for how this is tested :) A few things to consider:

Every PyRTL developer will need to run these tests for the indefinite future, so we really don't want the tests to run for too long. I'd say all continuously-run floating point tests should complete in a couple seconds or less. These tests can't be comprehensive, but we should still try to get as much value as we can out of them.

Try to think about the most interesting inputs (zeroes, inf, nan, ...), and which combinations of inputs and operations are worth testing. I would avoid randomly generated tests because they are unlikely to cover all the interesting combinations with a reasonable number of test cases.

Think about how someone besides yourself would debug one of these test failures. It would help a lot to know what cases a particular test exercises (# Test addition with inf, etc). It also helps to know if the unexpected output is slightly wrong, or complete nonsense, which is pretty hard to tell from a raw bit pattern (see next point :)

Consider adding some helper functions that can convert float to and from these bit patterns. That would make it easier for someone to understand that the test checks that 1.0 + 2.0 == 3.0, for example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, testing needs more work, I'll add more rigorous tests and test edge cases. That's why I marked it as a draft PR because it is not finished yet.


@staticmethod
def multiply(
def mul(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem very useful, since it just passes all its arguments through to MultiplicationHelper.multiply? Seems better to point users directly to MultiplicationHelper.multiply. Same comments for the other methods below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I have this FloatOperations class and separate MultiplicationHelper and AddSubHelper classes is because the user is supposed to use FloatOperations, MultiplicationHelper and AddSubHelper are internal helpers and the user is not supposed to use them. I created these helpers so I can separate the multiplication and addition/subtraction logic into separate files.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense, but I'd still try to refactor, it would be nice to remove these functions that are just pass-through wrappers for another function. Maybe:

  1. AddSubHelper.add could be an ordinary function instead of a @staticmethod
  2. _add_sub.py could be renamed to add_sub.py
  3. Non-public functions in add_sub.py could have an underscore prefix
  4. __init__.py could import add from add_sub

I think that would still make it clear which parts of the interface are meant for public use, while removing a layer of indirection?

This kind of indirection tends to be annoying when debugging problems, because you have to jump through another hoop to find the code you're looking for. "Our princess is in another castle!"

operand_a: pyrtl.WireVector,
operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
fp_type_props = config.fp_type_properties
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All user-facing methods/functions should check (somewhere, not necessarily here) that the operand bitwidths are consistent with config

@gaborszita gaborszita marked this pull request as ready for review January 27, 2026 04:12
Copy link
Member

@fdxmw fdxmw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the slow review! This code is making a lot more sense to me now, and I think it looks pretty good! I've only reviewed the main addition and subtraction code for now, I'll take a look at multiplication tomorrow. Here are some comments in the meantime.

Also I still stand by my earlier comments about trying to remove layers of indirection that don't do much work (FloatOperations, _BaseTypedFloatOperations) and trying to use WireStruct instead of calculating bit offsets :)


# RNE rounding uses the guard, round, and sticky bits.
# When shifting the smaller mantissa to the right, some bits are shifted out.
# The first bit shifted out becomes the guard bit, the second becomes the round bit,
Copy link
Member

@fdxmw fdxmw Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The words first and second here are a bit confusing because I can interpret it the other way, for example if I shift the bits one at a time, the first bit that gets shifted out is the least significant bit (smaller_operand_mantissa[0]), regardless of the shift_amount.

Maybe the first bit could be something like "the most significant bit to the right of the mantissa, after shifting"?

# shifted out are the guard and round bits, and the sticky bit is
# the OR of all remaining bits.
with smaller_mantissa_shift_amount >= 2:
guard_and_round = pyrtl.shift_right_logical(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but keep in mind that shift_right_logical creates a barrel shifter, which is pretty expensive, and we just made one on L75 above (and we'll make another on L95 below :). It should be possible to use one barrel shifter instead of three, by appending zero-valued GRS bits to the smaller_operand_mantissa before we shift it. That should work for the guard and round bits, but properly tracking the sticky bit would likely require changes to the barrel shifter.

pyrtl.Const(1), get_mantissa(fp_type_props, operand_larger)
)

# Align mantissas by shifting the smaller one to match the larger's exponent.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an important insight here is:

  1. Shifting the mantissa right by one divides the value by two.
  2. Adding one to the exponent multiplies the value by two.
  3. If we simultaneously do (1) and (2), we don't change the value, but it helps match the operand's exponents for easier math.

)

# Perform subtraction of operands.
sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = _sub_operands(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming seems a little inconsistent here, above _add_operands returns sum_exponent, while here _sub_operands returns sub_exponent. Seems like this should be called difference_exponent (addition results in a sum, subtraction results in a difference)?

larger_operand_exponent > larger_exponent_max_value
):
final_result_sign |= larger_operand_sign
if rounding_mode == RoundingMode.RNE:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty counterintuitive that rounding_mode determines whether you get inf or largest_finite_number on overflow. A reference would be really useful here

# subtract the rounding increment from the absolute minimum exponent,
# which is one greater than the all-0s exponent (reserved for
# zero and denormals).
initial_larger_exponent_min_value = pyrtl.Const(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this initialization should be = None, since it always gets overwritten in the if statement below? (calling pyrtl.Const(1) creates a Const and adds it to the Block, but that Const is never used)

# Round the result if using RNE rounding mode.
if rounding_mode == RoundingMode.RNE:
(
raw_result_rounded_exponent,
Copy link
Member

@fdxmw fdxmw Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem useful to give the rounded results new names (raw_result_rounded_exponent, etc), since we don't seem to use the not-rounded results (raw_result_exponent, etc). If this instead updated raw_result_exponent and friends, I think we could remove the conditional on L287 below, since we can always assign final_result_exponent |= raw_result_exponent ?

Also this function is quite long and there are a lot of variables, so it would really help readers to clarify variable lifetimes (which variables are still in use), and how groups of variables are related. Some ways to do this include:

  1. Avoid creating new variables when it makes sense to reuse an existing one (this suggestion)
  2. Try to define fewer variables with more internal structure, rather than defining a flat layer of top-level variables (raw_result.exponent rather than raw_result_exponent) (an earlier suggestion)
  3. Factor out more functions, which clearly defines the inputs and outputs for each chunk of code, and creates more local variable scopes
  4. Explicitly del variables that aren't needed anymore

):
final_result_sign |= larger_operand_sign
make_zero(final_result_exponent, final_result_mantissa)
with pyrtl.otherwise:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use a comment like "Otherwise no special cases apply: this is the common case" :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be called _float_utils.py? (only one L in utilities :)

fp_prop: FPTypeProperties, wire: pyrtl.WireVector
) -> pyrtl.WireVector:
"""
Returns whether the floating point number is denormalized.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A reference that defines denormalized numbers would be very helpful. https://en.wikipedia.org/wiki/Subnormal_number seems good, but maybe you've found something better?

return out


def make_inf(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little confusing that make_denormals_zero returns a WireVector while make_inf conditionally assigns its argument WireVectors. Consistency is clearer, so try to be consistent :)

I'd recommend just returning WireVectors in all these methods because the implementation is much easier to understand. Very few people know what happens when you do a |= assignment outside of a conditional_assignment block.


# Denormalized numbers are not supported, so we flush them to zero.
operands = (operand_a, operand_b)
operands_daz = tuple(make_denormals_zero(fp_type_props, op) for op in operands)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a similar comment on _add_sub.py, but I'd redefine operands here rather than defining a new operands_daz since we don't seem to use operands anymore after this point.

Similarly, this is a big function, so think about how someone new to this code will read it, and look for ways to help them navigate and understand what's important.

Don't forget that you have the Curse of Knowledge: you wrote all this code, so it already makes perfect sense to you :) To really test your code's readability, you have to show it to other people.

total_bits = num_exp_bits + num_mant_bits + 1

# Denormalized numbers are not supported, so we flush them to zero.
operand_a_daz = make_denormals_zero(fp_type_props, operand_a)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider grouping the operands into a tuple, for consistency with the multiplication implementation, and removing redundant code. Maybe it should be a sorted tuple, so the smaller operand is operands[0] and we wouldn't need more names to track smaller/larger?


def decode_float16(bits):
"""Decode Float16 bits to sign, exponent, and mantissa."""
return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, assert that bits fits in 16 bits

"""Tests for rounding in addition (RNE and RTZ)."""

def setUp(self):
pyrtl.reset_working_block()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these tests seem to use the same hardware, so it would help to factor this logic out to make it obvious that we're still using the same hardware, or to clearly point out any differences. Trying to determine this as a reader is not easy :) Some ideas:

  1. Just define all the tests that use the same hardware in the same class. You can add big section comments like
    ############################
    # Additional rounding tests.
    to group the different kinds of tests for people reading the code. I recommend this approach.
  2. Factor the hardware definition out into a function.
  3. Move the hardware definition to a base class.

self.assertTrue(is_nan(self.sim.inspect("result_rtz")))

def test_add_denormal_flushed_to_zero(self):
"""Test that denormal inputs are flushed to zero."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same nit about consistency, I've seen at least denormal, denorms, and denormalized in this PR. Try to pick one and use it consistently :)

def test_add_one_plus_two(self):
"""Test 1.0 + 2.0 = 3.0"""
self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_TWO})
self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_THREE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure you've run into this, but think about how someone would debug one of these assertEqual failures. I think you'd get a message that says

{big decimal number} is not equal to {other big decimal number}"

which is not very helpful.

Most likely, the first thing you'd do to debug a failure is to call decode_float16 to split out the sign, exponent, and mantissa on the unexpected value. Since this is the most likely first debug step, it'd really help the next person that works on this (which could be you! :) to integrate decode_float16 into these checks.

Maybe we can define a assertFloat16Equal helper that internally calls decode_float16 and separately calls assertEqual on the sign, exponent, and mantissa?

FLOAT16_NEG_ZERO = 0x8000
FLOAT16_POS_INF = 0x7C00
FLOAT16_NEG_INF = 0xFC00
FLOAT16_NAN = 0x7E00 # Quiet NaN with mantissa bit set
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants