Rollup merge of #146735 - Qelxiros:const_mul_add, r=tgross35,RalfJung

unstably constify float mul_add methods

Tracking issue: rust-lang/rust#146724
r? `@tgross35`
This commit is contained in:
Stuart Cook 2025-09-25 20:31:54 +10:00 committed by GitHub
commit 8e62f95376
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 118 additions and 135 deletions

View File

@ -636,6 +636,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
dest,
rustc_apfloat::Round::NearestTiesToEven,
)?,
sym::fmaf16 => self.fma_intrinsic::<Half>(args, dest)?,
sym::fmaf32 => self.fma_intrinsic::<Single>(args, dest)?,
sym::fmaf64 => self.fma_intrinsic::<Double>(args, dest)?,
sym::fmaf128 => self.fma_intrinsic::<Quad>(args, dest)?,
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest)?,
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest)?,
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest)?,
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest)?,
// Unsupported intrinsic: skip the return_to_block below.
_ => return interp_ok(false),
@ -1035,4 +1043,42 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
self.write_scalar(res, dest)?;
interp_ok(())
}
fn fma_intrinsic<F>(
&mut self,
args: &[OpTy<'tcx, M::Provenance>],
dest: &PlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, ()>
where
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
{
let a: F = self.read_scalar(&args[0])?.to_float()?;
let b: F = self.read_scalar(&args[1])?.to_float()?;
let c: F = self.read_scalar(&args[2])?.to_float()?;
let res = a.mul_add(b, c).value;
let res = self.adjust_nan(res, &[a, b, c]);
self.write_scalar(res, dest)?;
interp_ok(())
}
fn float_muladd_intrinsic<F>(
&mut self,
args: &[OpTy<'tcx, M::Provenance>],
dest: &PlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, ()>
where
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
{
let a: F = self.read_scalar(&args[0])?.to_float()?;
let b: F = self.read_scalar(&args[1])?.to_float()?;
let c: F = self.read_scalar(&args[2])?.to_float()?;
let fuse = M::float_fuse_mul_add(self);
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
let res = self.adjust_nan(res, &[a, b, c]);
self.write_scalar(res, dest)?;
interp_ok(())
}
}

View File

@ -289,6 +289,9 @@ pub trait Machine<'tcx>: Sized {
a
}
/// Determines whether the `fmuladd` intrinsics fuse the multiply-add or use separate operations.
fn float_fuse_mul_add(_ecx: &mut InterpCx<'tcx, Self>) -> bool;
/// Called before a basic block terminator is executed.
#[inline]
fn before_terminator(_ecx: &mut InterpCx<'tcx, Self>) -> InterpResult<'tcx> {
@ -672,6 +675,11 @@ pub macro compile_time_machine(<$tcx: lifetime>) {
match fn_val {}
}
#[inline(always)]
fn float_fuse_mul_add(_ecx: &mut InterpCx<$tcx, Self>) -> bool {
true
}
#[inline(always)]
fn ub_checks(_ecx: &InterpCx<$tcx, Self>) -> InterpResult<$tcx, bool> {
// We can't look at `tcx.sess` here as that can differ across crates, which can lead to

View File

@ -1312,28 +1312,28 @@ pub fn log2f128(x: f128) -> f128;
/// [`f16::mul_add`](../../std/primitive.f16.html#method.mul_add)
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmaf16(a: f16, b: f16, c: f16) -> f16;
pub const fn fmaf16(a: f16, b: f16, c: f16) -> f16;
/// Returns `a * b + c` for `f32` values.
///
/// The stabilized version of this intrinsic is
/// [`f32::mul_add`](../../std/primitive.f32.html#method.mul_add)
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmaf32(a: f32, b: f32, c: f32) -> f32;
pub const fn fmaf32(a: f32, b: f32, c: f32) -> f32;
/// Returns `a * b + c` for `f64` values.
///
/// The stabilized version of this intrinsic is
/// [`f64::mul_add`](../../std/primitive.f64.html#method.mul_add)
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmaf64(a: f64, b: f64, c: f64) -> f64;
pub const fn fmaf64(a: f64, b: f64, c: f64) -> f64;
/// Returns `a * b + c` for `f128` values.
///
/// The stabilized version of this intrinsic is
/// [`f128::mul_add`](../../std/primitive.f128.html#method.mul_add)
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmaf128(a: f128, b: f128, c: f128) -> f128;
pub const fn fmaf128(a: f128, b: f128, c: f128) -> f128;
/// Returns `a * b + c` for `f16` values, non-deterministically executing
/// either a fused multiply-add or two operations with rounding of the
@ -1347,7 +1347,7 @@ pub fn fmaf128(a: f128, b: f128, c: f128) -> f128;
/// example.
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
pub const fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
/// Returns `a * b + c` for `f32` values, non-deterministically executing
/// either a fused multiply-add or two operations with rounding of the
/// intermediate result.
@ -1360,7 +1360,7 @@ pub fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
/// example.
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
pub const fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
/// Returns `a * b + c` for `f64` values, non-deterministically executing
/// either a fused multiply-add or two operations with rounding of the
/// intermediate result.
@ -1373,7 +1373,7 @@ pub fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
/// example.
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
pub const fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
/// Returns `a * b + c` for `f128` values, non-deterministically executing
/// either a fused multiply-add or two operations with rounding of the
/// intermediate result.
@ -1386,7 +1386,7 @@ pub fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
/// example.
#[rustc_intrinsic]
#[rustc_nounwind]
pub fn fmuladdf128(a: f128, b: f128, c: f128) -> f128;
pub const fn fmuladdf128(a: f128, b: f128, c: f128) -> f128;
/// Returns the largest integer less than or equal to an `f16`.
///

View File

@ -1659,7 +1659,8 @@ impl f128 {
#[doc(alias = "fmaf128", alias = "fusedMultiplyAdd")]
#[unstable(feature = "f128", issue = "116909")]
#[must_use = "method returns a new number and does not mutate the original value"]
pub fn mul_add(self, a: f128, b: f128) -> f128 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(self, a: f128, b: f128) -> f128 {
intrinsics::fmaf128(self, a, b)
}

View File

@ -1634,7 +1634,8 @@ impl f16 {
#[unstable(feature = "f16", issue = "116909")]
#[doc(alias = "fmaf16", alias = "fusedMultiplyAdd")]
#[must_use = "method returns a new number and does not mutate the original value"]
pub fn mul_add(self, a: f16, b: f16) -> f16 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(self, a: f16, b: f16) -> f16 {
intrinsics::fmaf16(self, a, b)
}

View File

@ -1798,7 +1798,8 @@ pub mod math {
#[doc(alias = "fmaf", alias = "fusedMultiplyAdd")]
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "core_float_math", issue = "137578")]
pub fn mul_add(x: f32, y: f32, z: f32) -> f32 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(x: f32, y: f32, z: f32) -> f32 {
intrinsics::fmaf32(x, y, z)
}

View File

@ -1796,7 +1796,8 @@ pub mod math {
#[doc(alias = "fma", alias = "fusedMultiplyAdd")]
#[unstable(feature = "core_float_math", issue = "137578")]
#[must_use = "method returns a new number and does not mutate the original value"]
pub fn mul_add(x: f64, a: f64, b: f64) -> f64 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(x: f64, a: f64, b: f64) -> f64 {
intrinsics::fmaf64(x, a, b)
}

View File

@ -20,24 +20,6 @@ const TOL_PRECISE: f128 = 1e-28;
// FIXME(f16_f128,miri): many of these have to be disabled since miri does not yet support
// the intrinsics.
#[test]
#[cfg(not(miri))]
#[cfg(target_has_reliable_f128_math)]
fn test_mul_add() {
let nan: f128 = f128::NAN;
let inf: f128 = f128::INFINITY;
let neg_inf: f128 = f128::NEG_INFINITY;
assert_biteq!(12.3f128.mul_add(4.5, 6.7), 62.0500000000000000000000000000000037);
assert_biteq!((-12.3f128).mul_add(-4.5, -6.7), 48.6500000000000000000000000000000049);
assert_biteq!(0.0f128.mul_add(8.9, 1.2), 1.2);
assert_biteq!(3.4f128.mul_add(-0.0, 5.6), 5.6);
assert!(nan.mul_add(7.8, 9.0).is_nan());
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
assert_biteq!(8.9f128.mul_add(inf, 3.2), inf);
assert_biteq!((-3.2f128).mul_add(2.4, neg_inf), neg_inf);
}
#[test]
#[cfg(any(miri, target_has_reliable_f128_math))]
fn test_max_recip() {

View File

@ -22,24 +22,6 @@ const TOL_P4: f16 = 10.0;
// FIXME(f16_f128,miri): many of these have to be disabled since miri does not yet support
// the intrinsics.
#[test]
#[cfg(not(miri))]
#[cfg(target_has_reliable_f16_math)]
fn test_mul_add() {
let nan: f16 = f16::NAN;
let inf: f16 = f16::INFINITY;
let neg_inf: f16 = f16::NEG_INFINITY;
assert_biteq!(12.3f16.mul_add(4.5, 6.7), 62.031);
assert_biteq!((-12.3f16).mul_add(-4.5, -6.7), 48.625);
assert_biteq!(0.0f16.mul_add(8.9, 1.2), 1.2);
assert_biteq!(3.4f16.mul_add(-0.0, 5.6), 5.6);
assert!(nan.mul_add(7.8, 9.0).is_nan());
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
assert_biteq!(8.9f16.mul_add(inf, 3.2), inf);
assert_biteq!((-3.2f16).mul_add(2.4, neg_inf), neg_inf);
}
#[test]
#[cfg(any(miri, target_has_reliable_f16_math))]
fn test_max_recip() {

View File

@ -1,21 +0,0 @@
use core::f32;
use super::assert_biteq;
// FIXME(#140515): mingw has an incorrect fma https://sourceforge.net/p/mingw-w64/bugs/848/
#[cfg_attr(all(target_os = "windows", target_env = "gnu", not(target_abi = "llvm")), ignore)]
#[test]
fn test_mul_add() {
let nan: f32 = f32::NAN;
let inf: f32 = f32::INFINITY;
let neg_inf: f32 = f32::NEG_INFINITY;
assert_biteq!(f32::math::mul_add(12.3f32, 4.5, 6.7), 62.05);
assert_biteq!(f32::math::mul_add(-12.3f32, -4.5, -6.7), 48.65);
assert_biteq!(f32::math::mul_add(0.0f32, 8.9, 1.2), 1.2);
assert_biteq!(f32::math::mul_add(3.4f32, -0.0, 5.6), 5.6);
assert!(f32::math::mul_add(nan, 7.8, 9.0).is_nan());
assert_biteq!(f32::math::mul_add(inf, 7.8, 9.0), inf);
assert_biteq!(f32::math::mul_add(neg_inf, 7.8, 9.0), neg_inf);
assert_biteq!(f32::math::mul_add(8.9f32, inf, 3.2), inf);
assert_biteq!(f32::math::mul_add(-3.2f32, 2.4, neg_inf), neg_inf);
}

View File

@ -1,21 +0,0 @@
use core::f64;
use super::assert_biteq;
// FIXME(#140515): mingw has an incorrect fma https://sourceforge.net/p/mingw-w64/bugs/848/
#[cfg_attr(all(target_os = "windows", target_env = "gnu", not(target_abi = "llvm")), ignore)]
#[test]
fn test_mul_add() {
let nan: f64 = f64::NAN;
let inf: f64 = f64::INFINITY;
let neg_inf: f64 = f64::NEG_INFINITY;
assert_biteq!(12.3f64.mul_add(4.5, 6.7), 62.050000000000004);
assert_biteq!((-12.3f64).mul_add(-4.5, -6.7), 48.650000000000006);
assert_biteq!(0.0f64.mul_add(8.9, 1.2), 1.2);
assert_biteq!(3.4f64.mul_add(-0.0, 5.6), 5.6);
assert!(nan.mul_add(7.8, 9.0).is_nan());
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
assert_biteq!(8.9f64.mul_add(inf, 3.2), inf);
assert_biteq!((-3.2f64).mul_add(2.4, neg_inf), neg_inf);
}

View File

@ -34,6 +34,10 @@ trait TestableFloat: Sized {
const RAW_12_DOT_5: Self;
const RAW_1337: Self;
const RAW_MINUS_14_DOT_25: Self;
/// The result of 12.3.mul_add(4.5, 6.7)
const MUL_ADD_RESULT: Self;
/// The result of (-12.3).mul_add(-4.5, -6.7)
const NEG_MUL_ADD_RESULT: Self;
}
impl TestableFloat for f16 {
@ -58,6 +62,8 @@ impl TestableFloat for f16 {
const RAW_12_DOT_5: Self = Self::from_bits(0x4a40);
const RAW_1337: Self = Self::from_bits(0x6539);
const RAW_MINUS_14_DOT_25: Self = Self::from_bits(0xcb20);
const MUL_ADD_RESULT: Self = 62.031;
const NEG_MUL_ADD_RESULT: Self = 48.625;
}
impl TestableFloat for f32 {
@ -84,6 +90,8 @@ impl TestableFloat for f32 {
const RAW_12_DOT_5: Self = Self::from_bits(0x41480000);
const RAW_1337: Self = Self::from_bits(0x44a72000);
const RAW_MINUS_14_DOT_25: Self = Self::from_bits(0xc1640000);
const MUL_ADD_RESULT: Self = 62.05;
const NEG_MUL_ADD_RESULT: Self = 48.65;
}
impl TestableFloat for f64 {
@ -106,6 +114,8 @@ impl TestableFloat for f64 {
const RAW_12_DOT_5: Self = Self::from_bits(0x4029000000000000);
const RAW_1337: Self = Self::from_bits(0x4094e40000000000);
const RAW_MINUS_14_DOT_25: Self = Self::from_bits(0xc02c800000000000);
const MUL_ADD_RESULT: Self = 62.050000000000004;
const NEG_MUL_ADD_RESULT: Self = 48.650000000000006;
}
impl TestableFloat for f128 {
@ -128,6 +138,8 @@ impl TestableFloat for f128 {
const RAW_12_DOT_5: Self = Self::from_bits(0x40029000000000000000000000000000);
const RAW_1337: Self = Self::from_bits(0x40094e40000000000000000000000000);
const RAW_MINUS_14_DOT_25: Self = Self::from_bits(0xc002c800000000000000000000000000);
const MUL_ADD_RESULT: Self = 62.0500000000000000000000000000000037;
const NEG_MUL_ADD_RESULT: Self = 48.6500000000000000000000000000000049;
}
/// Determine the tolerance for values of the argument type.
@ -359,8 +371,6 @@ macro_rules! float_test {
mod f128;
mod f16;
mod f32;
mod f64;
float_test! {
name: num,
@ -1541,3 +1551,28 @@ float_test! {
assert_biteq!(Float::from_bits(masked_nan2), Float::from_bits(masked_nan2));
}
}
float_test! {
name: mul_add,
attrs: {
f16: #[cfg(any(miri, target_has_reliable_f16))],
// FIXME(#140515): mingw has an incorrect fma https://sourceforge.net/p/mingw-w64/bugs/848/
f32: #[cfg_attr(all(target_os = "windows", target_env = "gnu", not(target_abi = "llvm")), ignore)],
f64: #[cfg_attr(all(target_os = "windows", target_env = "gnu", not(target_abi = "llvm")), ignore)],
f128: #[cfg(any(miri, target_has_reliable_f128))],
},
test<Float> {
let nan: Float = Float::NAN;
let inf: Float = Float::INFINITY;
let neg_inf: Float = Float::NEG_INFINITY;
assert_biteq!(flt(12.3).mul_add(4.5, 6.7), Float::MUL_ADD_RESULT);
assert_biteq!((flt(-12.3)).mul_add(-4.5, -6.7), Float::NEG_MUL_ADD_RESULT);
assert_biteq!(flt(0.0).mul_add(8.9, 1.2), 1.2);
assert_biteq!(flt(3.4).mul_add(-0.0, 5.6), 5.6);
assert!(nan.mul_add(7.8, 9.0).is_nan());
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
assert_biteq!(flt(8.9).mul_add(inf, 3.2), inf);
assert_biteq!((flt(-3.2)).mul_add(2.4, neg_inf), neg_inf);
}
}

View File

@ -20,6 +20,7 @@
#![feature(const_convert)]
#![feature(const_destruct)]
#![feature(const_eval_select)]
#![feature(const_mul_add)]
#![feature(const_ops)]
#![feature(const_option_ops)]
#![feature(const_ref_cell)]

View File

@ -332,6 +332,7 @@
#![feature(char_internals)]
#![feature(clone_to_uninit)]
#![feature(const_convert)]
#![feature(const_mul_add)]
#![feature(core_intrinsics)]
#![feature(core_io_borrowed_buf)]
#![feature(drop_guard)]

View File

@ -217,7 +217,8 @@ impl f32 {
#[must_use = "method returns a new number and does not mutate the original value"]
#[stable(feature = "rust1", since = "1.0.0")]
#[inline]
pub fn mul_add(self, a: f32, b: f32) -> f32 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(self, a: f32, b: f32) -> f32 {
core::f32::math::mul_add(self, a, b)
}

View File

@ -217,7 +217,8 @@ impl f64 {
#[must_use = "method returns a new number and does not mutate the original value"]
#[stable(feature = "rust1", since = "1.0.0")]
#[inline]
pub fn mul_add(self, a: f64, b: f64) -> f64 {
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
pub const fn mul_add(self, a: f64, b: f64) -> f64 {
core::f64::math::mul_add(self, a, b)
}

View File

@ -1,4 +1,3 @@
use rand::Rng;
use rustc_apfloat::{self, Float, FloatConvert, Round};
use rustc_middle::mir;
use rustc_middle::ty::{self, FloatTy};
@ -39,46 +38,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
"sqrtf64" => sqrt::<rustc_apfloat::ieee::Double>(this, args, dest)?,
"sqrtf128" => sqrt::<rustc_apfloat::ieee::Quad>(this, args, dest)?,
"fmaf32" => {
let [a, b, c] = check_intrinsic_arg_count(args)?;
let a = this.read_scalar(a)?.to_f32()?;
let b = this.read_scalar(b)?.to_f32()?;
let c = this.read_scalar(c)?.to_f32()?;
let res = a.mul_add(b, c).value;
let res = this.adjust_nan(res, &[a, b, c]);
this.write_scalar(res, dest)?;
}
"fmaf64" => {
let [a, b, c] = check_intrinsic_arg_count(args)?;
let a = this.read_scalar(a)?.to_f64()?;
let b = this.read_scalar(b)?.to_f64()?;
let c = this.read_scalar(c)?.to_f64()?;
let res = a.mul_add(b, c).value;
let res = this.adjust_nan(res, &[a, b, c]);
this.write_scalar(res, dest)?;
}
"fmuladdf32" => {
let [a, b, c] = check_intrinsic_arg_count(args)?;
let a = this.read_scalar(a)?.to_f32()?;
let b = this.read_scalar(b)?.to_f32()?;
let c = this.read_scalar(c)?.to_f32()?;
let fuse: bool = this.machine.float_nondet && this.machine.rng.get_mut().random();
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
let res = this.adjust_nan(res, &[a, b, c]);
this.write_scalar(res, dest)?;
}
"fmuladdf64" => {
let [a, b, c] = check_intrinsic_arg_count(args)?;
let a = this.read_scalar(a)?.to_f64()?;
let b = this.read_scalar(b)?.to_f64()?;
let c = this.read_scalar(c)?.to_f64()?;
let fuse: bool = this.machine.float_nondet && this.machine.rng.get_mut().random();
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
let res = this.adjust_nan(res, &[a, b, c]);
this.write_scalar(res, dest)?;
}
#[rustfmt::skip]
| "fadd_fast"
| "fsub_fast"

View File

@ -1293,6 +1293,11 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
ecx.equal_float_min_max(a, b)
}
#[inline(always)]
fn float_fuse_mul_add(ecx: &mut InterpCx<'tcx, Self>) -> bool {
ecx.machine.float_nondet && ecx.machine.rng.get_mut().random()
}
#[inline(always)]
fn ub_checks(ecx: &InterpCx<'tcx, Self>) -> InterpResult<'tcx, bool> {
interp_ok(ecx.tcx.sess.ub_checks())