autodiff: fixed test to be more precise for type tree checking

This commit is contained in:
Karan Janthe 2025-09-04 11:17:34 +00:00
parent 574f0b97d6
commit 4f3f0f48e7
19 changed files with 120 additions and 87 deletions

View File

@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD {
max_size: i64,
add_offset: u64,
);
pub(crate) fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
}
@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
unimplemented!()
}
@ -312,6 +329,12 @@ impl TypeTree {
self
}
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
unsafe {
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
}
}
}
impl Clone for TypeTree {

View File

@ -8,22 +8,24 @@ use {
use crate::llvm::{self, Value};
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
///
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
/// and converts it to Enzyme's internal C++ TypeTree representation that
/// Enzyme can understand during differentiation analysis.
#[cfg(llvm_enzyme)]
fn to_enzyme_typetree(
rust_typetree: RustTypeTree,
data_layout: &str,
_data_layout: &str,
llcx: &llvm::Context,
) -> llvm::TypeTree {
// Start with an empty TypeTree
let mut enzyme_tt = llvm::TypeTree::new();
// Convert each Type in the Rust TypeTree to Enzyme format
for rust_type in rust_typetree.0 {
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
enzyme_tt
}
#[cfg(llvm_enzyme)]
fn process_typetree_recursive(
enzyme_tt: &mut llvm::TypeTree,
rust_typetree: &RustTypeTree,
parent_indices: &[i64],
llcx: &llvm::Context,
) {
for rust_type in &rust_typetree.0 {
let concrete_type = match rust_type.kind {
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
@ -35,25 +37,27 @@ fn to_enzyme_typetree(
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
};
// Create a TypeTree for this specific type
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
// Apply offset if specified
let type_tt = if rust_type.offset == -1 {
type_tt // -1 means everywhere/no specific offset
let mut indices = parent_indices.to_vec();
if !parent_indices.is_empty() {
if rust_type.offset == -1 {
indices.push(-1);
} else {
indices.push(rust_type.offset as i64);
}
} else if rust_type.offset == -1 {
indices.push(-1);
} else {
// Apply specific offset positioning
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
};
indices.push(rust_type.offset as i64);
}
// Merge this type into the main TypeTree
enzyme_tt = enzyme_tt.merge(type_tt);
enzyme_tt.insert(&indices, concrete_type, llcx);
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
}
}
enzyme_tt
}
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
#[cfg(llvm_enzyme)]
pub(crate) fn add_tt<'ll>(
llmod: &'ll llvm::Module,
@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>(
let inputs = tt.args;
let ret_tt: RustTypeTree = tt.ret;
// Get LLVM data layout string for TypeTree conversion
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
// Attribute name that Enzyme recognizes for TypeTree information
let attr_name = "enzyme_type";
let c_attr_name = CString::new(attr_name).unwrap();
// Attach TypeTree attributes to each input parameter
// Enzyme uses these to understand parameter memory layouts during differentiation
for (i, input) in inputs.iter().enumerate() {
unsafe {
// Convert Rust TypeTree to Enzyme's internal format
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
// Serialize TypeTree to string format that Enzyme can parse
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
// Create LLVM string attribute with TypeTree information
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>(
c_str.to_bytes().len() as c_uint,
);
// Attach attribute to the specific function parameter
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
// Free the C string to prevent memory leaks
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
}
}
// Attach TypeTree attribute to the return type
// Enzyme needs this to understand how to handle return value derivatives
unsafe {
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>(
c_str.to_bytes().len() as c_uint,
);
// Attach to function return type
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
// Free the C string
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
}
}
// Fallback implementation when Enzyme is not available
#[cfg(not(llvm_enzyme))]
pub(crate) fn add_tt<'ll>(
_llmod: &'ll llvm::Module,

View File

@ -2261,10 +2261,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
x if x == tcx.types.f32 => (Kind::Float, 4),
x if x == tcx.types.f64 => (Kind::Double, 8),
x if x == tcx.types.f128 => (Kind::F128, 16),
_ => return TypeTree::new(),
_ => (Kind::Integer, 0),
}
} else {
return TypeTree::new();
(Kind::Integer, 0)
};
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
@ -2295,32 +2295,14 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let element_tree = typetree_from_ty(tcx, *element_ty);
let element_layout = tcx
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
.ok()
.map(|layout| layout.size.bytes_usize())
.unwrap_or(0);
if element_layout == 0 {
return TypeTree::new();
}
let mut types = Vec::new();
for i in 0..len {
let base_offset = (i as usize * element_layout) as isize;
for elem_type in &element_tree.0 {
types.push(Type {
offset: if elem_type.offset == -1 {
base_offset
} else {
base_offset + elem_type.offset
},
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
for elem_type in &element_tree.0 {
types.push(Type {
offset: -1,
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
return TypeTree(types);

View File

@ -1,4 +1,4 @@
; Check that array TypeTree metadata is correctly generated
; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View File

@ -1,7 +1,7 @@
; Check that enzyme_type attributes are present in the LLVM IR function definition
; This verifies our TypeTree system correctly attaches metadata for Enzyme
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
; Check that llvm.memcpy exists (either call or declare)
CHECK: {{(call|declare).*}}@llvm.memcpy

View File

@ -0,0 +1,2 @@
; Check that mixed struct with large array generates correct detailed type tree
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}"

View File

@ -0,0 +1,16 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc()
.input("test.rs")
.arg("-Zautodiff=Enable")
.arg("-Zautodiff=NoPostopt")
.opt_level("0")
.emit("llvm-ir")
.run();
llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run();
}

View File

@ -0,0 +1,23 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[repr(C)]
struct Container {
header: i64,
data: [f32; 1000],
}
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
#[inline(never)]
fn test_mixed_struct(container: &Container) -> f32 {
container.data[0] + container.data[999]
}
fn main() {
let container = Container { header: 42, data: [1.0; 1000] };
let mut d_container = Container { header: 0, data: [0.0; 1000] };
let result = d_test(&container, &mut d_container, 1.0);
std::hint::black_box(result);
}

View File

@ -1,4 +1,4 @@
// Check that enzyme_type attributes are present when TypeTree is enabled
// This verifies our TypeTree metadata attachment is working
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View File

@ -1,4 +1,4 @@
; Check that f128 TypeTree metadata is correctly generated
; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128
; Should show Float@fp128 for f128 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"

View File

@ -1,4 +1,4 @@
; Check that f16 TypeTree metadata is correctly generated
; Should show Half for f16 values and Pointer for references
; Should show Float@half for f16 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f16{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"

View File

@ -1,4 +1,4 @@
; Check that f32 TypeTree metadata is correctly generated
; Should show Float@float for f32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"

View File

@ -1,4 +1,4 @@
; Check that f64 TypeTree metadata is correctly generated
; Should show Float@double for f64 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View File

@ -1,4 +1,4 @@
; Check that i32 TypeTree metadata is correctly generated
; Should show Integer for i32 values (integers are typically Const in autodiff)
; Should show Integer for i32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[]:Integer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"

View File

@ -2,13 +2,14 @@
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Const, Active)]
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_i32(x: i32) -> i32 {
fn test_i32(x: &i32) -> i32 {
x * x
}
fn main() {
let x = 5_i32;
let _result = d_test(x, 1);
let mut dx = 0_i32;
let _result = d_test(&x, &mut dx, 1);
}

View File

@ -1,4 +1,4 @@
; Check that slice TypeTree metadata is correctly generated
; Should show Float@double for slice elements
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View File

@ -1,4 +1,4 @@
; Check that struct TypeTree metadata is correctly generated
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"

View File

@ -1,4 +1,4 @@
; Check that tuple TypeTree metadata is correctly generated
; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64)
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"

View File

@ -1,7 +1,7 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !102, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 16, !dbg !{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = load i64, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = icmp eq i64 %{{[0-9]+}}, 0, !dbg !{{[0-9]+}}: {[-1]:Integer}