[Mlir-commits] [mlir] 3359806 - [mlir][LLVM][MemRef] Lower assume_alignment with operand bundles (#117800)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 26 15:20:42 PST 2024
Author: Krzysztof Drewniak
Date: 2024-11-26T17:20:39-06:00
New Revision: 3359806817d5d3a600e3f0bdae60ac3df1c85e7f
URL: https://github.com/llvm/llvm-project/commit/3359806817d5d3a600e3f0bdae60ac3df1c85e7f
DIFF: https://github.com/llvm/llvm-project/commit/3359806817d5d3a600e3f0bdae60ac3df1c85e7f.diff
LOG: [mlir][LLVM][MemRef] Lower assume_alignment with operand bundles (#117800)
Now that LLVM allows an operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.
This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with pointer
provenance).
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..86f687d7f2636e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -192,22 +192,15 @@ struct AssumeAlignmentOpLowering
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
rewriter);
- // Emit llvm.assume(memref & (alignment - 1) == 0).
- //
- // This relies on LLVM's CSE optimization (potentially after SROA), since
- // after CSE all memref instances should get de-duplicated into the same
- // pointer SSA value.
- MemRefDescriptor memRefDescriptor(memref);
- auto intPtrType =
- getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
- Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
- Value mask =
- createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
- Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
- rewriter.create<LLVM::AssumeOp>(
- loc, rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::ICmpPredicate::eq,
- rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
+ // Emit llvm.assume(true) ["align"(memref, alignment)].
+ // This is more direct than ptrtoint-based checks, is explicitly supported,
+ // and works with non-integral address spaces.
+ Value trueCond =
+ rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+ Value alignmentConst =
+ createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
+ rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+ alignmentConst);
rewriter.eraseOp(op);
return success();
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index ec5ceae57ccb33..a78db9733b7eef 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
-// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: llvm.intr.assume %[[CMP]] : i1
+// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1
// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
// CHECK: return %[[VAL]] : f32
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 48dc9079333d4f..67b68b7a1c0444 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-LABEL: func @assume_alignment(
func.func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
- // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
- // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
- // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+ // CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+ // CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
+ // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}
@@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
// CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
- // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
- // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
- // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
- // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+ // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+ // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
+ // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
return
}
@@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
// CHECK-SAME: %[[ARG2:.+]]: index
// CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
-// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
// CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
@@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
%0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
- // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
+ // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
// CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
// CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
More information about the Mlir-commits
mailing list