[PATCH] D148930: [MemRefToLLVM] Fix the lowering of memref.assume_alignment

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 25 21:10:15 PDT 2023


This revision was automatically updated to reflect the committed changes.
qcolombet marked an inline comment as done.
Closed by commit rG68e1aef68e40: [MemRefToLLVM] Fix the lowering of memref.assume_alignment (authored by qcolombet).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D148930/new/

https://reviews.llvm.org/D148930

Files:
  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
  mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir


Index: mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -130,7 +130,7 @@
 
 // -----
 
-// CHECK-LABEL: func @assume_alignment
+// CHECK-LABEL: func @assume_alignment(
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -145,6 +145,22 @@
 
 // -----
 
+// CHECK-LABEL: func @assume_alignment_w_offset
+func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
+  // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK]] : i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
+  return
+}
+// -----
+
 // CHECK-LABEL: func @dim_of_unranked
 // CHECK32-LABEL: func @dim_of_unranked
 func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
Index: mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -340,16 +340,28 @@
     unsigned alignment = op.getAlignment();
     auto loc = op.getLoc();
 
+    auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+    // When we convert to LLVM, the input memref must have been normalized
+    // beforehand. Hence, this call is guaranteed to work.
+    auto [strides, offset] = getStridesAndOffset(srcMemRefType);
+
     MemRefDescriptor memRefDescriptor(memref);
     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+    // Skip if offset is zero.
+    if (offset != 0) {
+      Value offsetVal = ShapedType::isDynamic(offset)
+                            ? memRefDescriptor.offset(rewriter, loc)
+                            : createIndexConstant(rewriter, loc, offset);
+      Type elementType =
+          typeConverter->convertType(srcMemRefType.getElementType());
+      ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+                                         offsetVal);
+    }
 
-    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
-    // the asserted memref.alignedPtr isn't used anywhere else, as the real
-    // users like load/store/views always re-extract memref.alignedPtr as they
-    // get lowered.
+    // Emit llvm.assume(memref & (alignment - 1) == 0).
     //
     // This relies on LLVM's CSE optimization (potentially after SROA), since
-    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // after CSE all memref instances should get de-duplicated into the same
     // pointer SSA value.
     auto intPtrType =
         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D148930.517038.patch
Type: text/x-patch
Size: 3772 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230426/5901c958/attachment.bin>


More information about the llvm-commits mailing list