[Mlir-commits] [mlir] 3359806 - [mlir][LLVM][MemRef] Lower assume_alignment with operand bundles (#117800)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 15:20:42 PST 2024


Author: Krzysztof Drewniak
Date: 2024-11-26T17:20:39-06:00
New Revision: 3359806817d5d3a600e3f0bdae60ac3df1c85e7f

URL: https://github.com/llvm/llvm-project/commit/3359806817d5d3a600e3f0bdae60ac3df1c85e7f
DIFF: https://github.com/llvm/llvm-project/commit/3359806817d5d3a600e3f0bdae60ac3df1c85e7f.diff

LOG: [mlir][LLVM][MemRef] Lower assume_alignment with operand bundles (#117800)

Now that LLVM allows an operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.

This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with pointer
provenance).

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..86f687d7f2636e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -192,22 +192,15 @@ struct AssumeAlignmentOpLowering
     Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                      rewriter);
 
-    // Emit llvm.assume(memref & (alignment - 1) == 0).
-    //
-    // This relies on LLVM's CSE optimization (potentially after SROA), since
-    // after CSE all memref instances should get de-duplicated into the same
-    // pointer SSA value.
-    MemRefDescriptor memRefDescriptor(memref);
-    auto intPtrType =
-        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
-    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
-    Value mask =
-        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
-    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
-    rewriter.create<LLVM::AssumeOp>(
-        loc, rewriter.create<LLVM::ICmpOp>(
-                 loc, LLVM::ICmpPredicate::eq,
-                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
+    // Emit llvm.assume(true) ["align"(memref, alignment)].
+    // This is more direct than ptrtoint-based checks, is explicitly supported,
+    // and works with non-integral address spaces.
+    Value trueCond =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+    Value alignmentConst =
+        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
+    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+                                    alignmentConst);
 
     rewriter.eraseOp(op);
     return success();

diff  --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index ec5ceae57ccb33..a78db9733b7eef 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
 // CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}}  : i64
-// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: llvm.intr.assume %[[CMP]] : i1
+// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1
 // CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
 // CHECK: return %[[VAL]] : f32

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 48dc9079333d4f..67b68b7a1c0444 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
 // CHECK-LABEL: func @assume_alignment(
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
-  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
-  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
-  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+  // CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  // CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
+    // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
@@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
   // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[BUFF_ADDR:.*]] =  llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
-  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
-  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
-  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+  // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
+    // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
   return
 }
@@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
 // CHECK-SAME:   %[[ARG2:.+]]: index
 // CHECK-DAG:    %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:    %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
-// CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+// CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK:        %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
 // CHECK:        %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
 // CHECK:        %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
@@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 // CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
 func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
-  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> 
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
   // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
   // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
   // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64


        


More information about the Mlir-commits mailing list