[Mlir-commits] [mlir] 68e1aef - [MemRefToLLVM] Fix the lowering of memref.assume_alignment

Quentin Colombet llvmlistbot at llvm.org
Tue Apr 25 21:09:59 PDT 2023


Author: Quentin Colombet
Date: 2023-04-26T06:07:00+02:00
New Revision: 68e1aef68e40452b6c176b25e67c13a0359c96ca

URL: https://github.com/llvm/llvm-project/commit/68e1aef68e40452b6c176b25e67c13a0359c96ca
DIFF: https://github.com/llvm/llvm-project/commit/68e1aef68e40452b6c176b25e67c13a0359c96ca.diff

LOG: [MemRefToLLVM] Fix the lowering of memref.assume_alignment

`memref.assume_alignment` annotates the alignment of the source buffer,
not the base pointer.
Put differently, prior to this patch `memref.assume_alignment` would lower
to `llvm.assume %buffer.base.isAligned(X)`, whereas what we want is
`llvm.assume (%buffer.base + %buffer.offset).isAligned(X)`.
In other words, we were failing to include the offset in the expression
checked by the `llvm.assume`.
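
For illustration, a rough sketch of the intended lowering for a memref with
a dynamic offset (the SSA value names are made up; the operations mirror the
new test case added below):

    // memref.assume_alignment %arg0, 16
    //     : memref<4x4xf16, strided<[?, ?], offset: ?>>
    // lowers roughly to:
    %ptr    = llvm.extractvalue %memref[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %off    = llvm.extractvalue %memref[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    // New in this patch: fold the offset into the address before asserting alignment.
    %addr   = llvm.getelementptr %ptr[%off] : (!llvm.ptr, i64) -> !llvm.ptr, f16
    %zero   = llvm.mlir.constant(0 : index) : i64
    %mask   = llvm.mlir.constant(15 : index) : i64   // alignment - 1
    %int    = llvm.ptrtoint %addr : !llvm.ptr to i64
    %masked = llvm.and %int, %mask : i64
    %cond   = llvm.icmp "eq" %masked, %zero : i64
    "llvm.intr.assume"(%cond) : (i1) -> ()

When the offset is statically known to be zero, the `getelementptr` is
skipped and the aligned pointer is asserted directly, as before.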

Differential Revision: https://reviews.llvm.org/D148930

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 07d406bbb1c1a..c0a3a78434de9 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -340,16 +340,28 @@ struct AssumeAlignmentOpLowering
     unsigned alignment = op.getAlignment();
     auto loc = op.getLoc();
 
+    auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+    // When we convert to LLVM, the input memref must have been normalized
+    // beforehand. Hence, this call is guaranteed to work.
+    auto [strides, offset] = getStridesAndOffset(srcMemRefType);
+
     MemRefDescriptor memRefDescriptor(memref);
     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+    // Skip if offset is zero.
+    if (offset != 0) {
+      Value offsetVal = ShapedType::isDynamic(offset)
+                            ? memRefDescriptor.offset(rewriter, loc)
+                            : createIndexConstant(rewriter, loc, offset);
+      Type elementType =
+          typeConverter->convertType(srcMemRefType.getElementType());
+      ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+                                         offsetVal);
+    }
 
-    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
-    // the asserted memref.alignedPtr isn't used anywhere else, as the real
-    // users like load/store/views always re-extract memref.alignedPtr as they
-    // get lowered.
+    // Emit llvm.assume(memref & (alignment - 1) == 0).
     //
     // This relies on LLVM's CSE optimization (potentially after SROA), since
-    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // after CSE all memref instances should get de-duplicated into the same
     // pointer SSA value.
     auto intPtrType =
         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());

diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 4b4b3836a0075..5b58198bcfd82 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -130,7 +130,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
 
 // -----
 
-// CHECK-LABEL: func @assume_alignment
+// CHECK-LABEL: func @assume_alignment(
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -145,6 +145,22 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @assume_alignment_w_offset
+func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
+  // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[BUFF_ADDR:.*]] =  llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
+  return
+}
+// -----
+
 // CHECK-LABEL: func @dim_of_unranked
 // CHECK32-LABEL: func @dim_of_unranked
 func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {


        

