[Mlir-commits] [mlir] 68e1aef - [MemRefToLLVM] Fix the lowering of memref.assume_alignment
Quentin Colombet
llvmlistbot at llvm.org
Tue Apr 25 21:09:59 PDT 2023
Author: Quentin Colombet
Date: 2023-04-26T06:07:00+02:00
New Revision: 68e1aef68e40452b6c176b25e67c13a0359c96ca
URL: https://github.com/llvm/llvm-project/commit/68e1aef68e40452b6c176b25e67c13a0359c96ca
DIFF: https://github.com/llvm/llvm-project/commit/68e1aef68e40452b6c176b25e67c13a0359c96ca.diff
LOG: [MemRefToLLVM] Fix the lowering of memref.assume_alignment
`memref.assume_alignment` annotates the alignment of the source buffer
(base pointer plus offset), not of the base pointer alone.
Put differently, prior to this patch `memref.assume_alignment` would lower
to `llvm.assume %buffer.base.isAligned(X)` whereas what we want is
`llvm.assume (%buffer.base + %buffer.offset).isAligned(X)`.
In other words, we were failing to include the offset in the expression
checked by the `llvm.assume`.
Differential Revision: https://reviews.llvm.org/D148930
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 07d406bbb1c1a..c0a3a78434de9 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -340,16 +340,28 @@ struct AssumeAlignmentOpLowering
unsigned alignment = op.getAlignment();
auto loc = op.getLoc();
+ auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+ // When we convert to LLVM, the input memref must have been normalized
+ // beforehand. Hence, this call is guaranteed to work.
+ auto [strides, offset] = getStridesAndOffset(srcMemRefType);
+
MemRefDescriptor memRefDescriptor(memref);
Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+ // Skip if offset is zero.
+ if (offset != 0) {
+ Value offsetVal = ShapedType::isDynamic(offset)
+ ? memRefDescriptor.offset(rewriter, loc)
+ : createIndexConstant(rewriter, loc, offset);
+ Type elementType =
+ typeConverter->convertType(srcMemRefType.getElementType());
+ ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+ offsetVal);
+ }
- // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
- // the asserted memref.alignedPtr isn't used anywhere else, as the real
- // users like load/store/views always re-extract memref.alignedPtr as they
- // get lowered.
+ // Emit llvm.assume(memref & (alignment - 1) == 0).
//
// This relies on LLVM's CSE optimization (potentially after SROA), since
- // after CSE all memref.alignedPtr instances get de-duplicated into the same
+ // after CSE all memref instances should get de-duplicated into the same
// pointer SSA value.
auto intPtrType =
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 4b4b3836a0075..5b58198bcfd82 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -130,7 +130,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// -----
-// CHECK-LABEL: func @assume_alignment
+// CHECK-LABEL: func @assume_alignment(
func.func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -145,6 +145,22 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
// -----
+// Checks that, for a memref with a dynamic offset, the offset is folded into
+// the pointer (via llvm.getelementptr) before the alignment is asserted with
+// llvm.intr.assume, and that the mask used is (alignment - 1) = 15.
+// CHECK-LABEL: func @assume_alignment_w_offset
+func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
+  // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[BUFF_ADDR:.*]] =  llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+  // Use (not re-define) MASK so the operand is pinned to the constant above.
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK]] : i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
+  return
+}
+// -----
+
// CHECK-LABEL: func @dim_of_unranked
// CHECK32-LABEL: func @dim_of_unranked
func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
More information about the Mlir-commits
mailing list