[Mlir-commits] [mlir] e02d414 - [MemRefToLLVM] Add a method in MemRefDescriptor to get the buffer addr
Quentin Colombet
llvmlistbot at llvm.org
Tue Apr 25 21:14:48 PDT 2023
Author: Quentin Colombet
Date: 2023-04-26T06:12:18+02:00
New Revision: e02d4142dca48d9359ad304fec629ea3d7e6924c
URL: https://github.com/llvm/llvm-project/commit/e02d4142dca48d9359ad304fec629ea3d7e6924c
DIFF: https://github.com/llvm/llvm-project/commit/e02d4142dca48d9359ad304fec629ea3d7e6924c.diff
LOG: [MemRefToLLVM] Add a method in MemRefDescriptor to get the buffer addr
This patch pushes the computation of the start address of a memref into one
place (a method in MemRefDescriptor).
This allows all the (indirect) users of this method to produce the start
address in the same way.
Thanks to this change, we expose more CSE opportunities and thanks to
that, the backend is able to properly find the `llvm.assume` expression
related to the base address as demonstrated in the added test.
Differential Revision: https://reviews.llvm.org/D148947
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index 76ed89563fd7b..28d37a91edb80 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -89,6 +89,14 @@ class MemRefDescriptor : public StructBuilder {
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();
+ /// Builds IR for getting the start address of the buffer represented
+ /// by this memref:
+ /// `memref.alignedPtr + memref.offset * sizeof(type.getElementType())`.
+ /// \note there is no setter for this one since it is derived from alignedPtr
+ /// and offset.
+ Value bufferPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter, MemRefType type);
+
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
/// - allocated pointer;
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 248aab3d6838b..2373765dae007 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -199,6 +199,28 @@ LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
.cast<LLVM::LLVMPointerType>();
}
+Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter,
+ MemRefType type) {
+ // When we convert to LLVM, the input memref must have been normalized
+ // beforehand. Hence, this call is guaranteed to work.
+ auto [strides, offsetCst] = getStridesAndOffset(type);
+
+ Value ptr = alignedPtr(builder, loc);
+ // Skip if offset is zero.
+ if (offsetCst != 0) {
+ Type indexType = converter.getIndexType();
+ Value offsetVal =
+ ShapedType::isDynamic(offsetCst)
+ ? offset(builder, loc)
+ : createIndexAttrConstant(builder, loc, indexType, offsetCst);
+ Type elementType = converter.convertType(type.getElementType());
+ ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+ offsetVal);
+ }
+ return ptr;
+}
+
/// Creates a MemRef descriptor structure from a list of individual values
/// composing that descriptor, in the following order:
/// - allocated pointer;
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index e2dae4044504b..67a2898c02050 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -72,14 +72,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
auto [strides, offset] = getStridesAndOffset(type);
MemRefDescriptor memRefDescriptor(memRefDesc);
- Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+ // Use a canonical representation of the start address so that later
+ // optimizations have a longer sequence of instructions to CSE.
+ // If we don't do that we would sprinkle the memref.offset in various
+  // position of the different address computations.
+ Value base =
+ memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
Value index;
- if (offset != 0) // Skip if offset is zero.
- index = ShapedType::isDynamic(offset)
- ? memRefDescriptor.offset(rewriter, loc)
- : createIndexConstant(rewriter, loc, offset);
-
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c0a3a78434de9..e9fbad30783c3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -332,6 +332,8 @@ struct AssumeAlignmentOpLowering
: public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
using ConvertOpToLLVMPattern<
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
+ explicit AssumeAlignmentOpLowering(LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
LogicalResult
matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
@@ -341,28 +343,15 @@ struct AssumeAlignmentOpLowering
auto loc = op.getLoc();
auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
- // When we convert to LLVM, the input memref must have been normalized
- // beforehand. Hence, this call is guaranteed to work.
- auto [strides, offset] = getStridesAndOffset(srcMemRefType);
-
- MemRefDescriptor memRefDescriptor(memref);
- Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
- // Skip if offset is zero.
- if (offset != 0) {
- Value offsetVal = ShapedType::isDynamic(offset)
- ? memRefDescriptor.offset(rewriter, loc)
- : createIndexConstant(rewriter, loc, offset);
- Type elementType =
- typeConverter->convertType(srcMemRefType.getElementType());
- ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
- offsetVal);
- }
+ Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
+ rewriter);
// Emit llvm.assume(memref & (alignment - 1) == 0).
//
// This relies on LLVM's CSE optimization (potentially after SROA), since
// after CSE all memref instances should get de-duplicated into the same
// pointer SSA value.
+ MemRefDescriptor memRefDescriptor(memref);
auto intPtrType =
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 7a21bf36b8537..0e655b5464d96 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -668,3 +668,32 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
return %1 : memref<64xf32, strided<[1], offset: ?>>
}
+
+// -----
+
+// Check that the address of %arg0 appears with the same value
+// in both the llvm.assume and as base of the load.
+// This is to make sure that later CSEs and alignment propagation
+// will be able to do their job easily.
+
+// CHECK-LABEL: func @load_and_assume(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
+// CHECK: "llvm.intr.assume"(%[[CMP]]) : (i1) -> ()
+// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
+// CHECK: return %[[VAL]] : f32
+func.func @load_and_assume(
+ %arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %i0: index, %i1: index)
+ -> f32 {
+ memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ func.return %2 : f32
+}
More information about the Mlir-commits
mailing list