[PATCH] D137364: [mlir][MemRefToLLVM] Fix the lowering of extract_strided_metadata
Quentin Colombet via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 3 15:10:34 PDT 2022
qcolombet created this revision.
qcolombet added reviewers: nicolasvasilache, ftynse.
qcolombet added a project: MLIR.
Herald added subscribers: Moerafaat, zero9178, bzcheeseman, awarzynski, sdasgup3, wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a project: All.
qcolombet requested review of this revision.
Herald added a subscriber: stephenneuendorffer.
Herald added a reviewer: dcaballe.
The first result of the extract_strided_metadata operation (the base buffer) is a MemRef, not a naked pointer.
This patch fixes the lowering of this operation in MemRefToLLVM so that it materializes the full MemRef descriptor for the base buffer (allocated pointer, aligned pointer, and offset) instead of returning only the allocated base pointer.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D137364
Files:
mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Index: mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1168,7 +1168,13 @@
// CHECK-LABEL: func @extract_strided_metadata(
// CHECK-SAME: %[[ARG:.*]]: memref
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
Index: mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2115,7 +2115,7 @@
return failure();
// Create the descriptor.
- MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
+ MemRefDescriptor sourceMemRef(adaptor.getSource());
Location loc = extractStridedMetadataOp.getLoc();
Value source = extractStridedMetadataOp.getSource();
@@ -2125,7 +2125,12 @@
results.reserve(2 + rank * 2);
// Base buffer.
- results.push_back(sourceMemRef.allocatedPtr(rewriter, loc));
+ MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(),
+ extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+ sourceMemRef.allocatedPtr(rewriter, loc),
+ sourceMemRef.alignedPtr(rewriter, loc));
+ results.push_back((Value)dstMemRef);
// Offset.
results.push_back(sourceMemRef.offset(rewriter, loc));
Index: mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
===================================================================
--- mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -43,6 +43,12 @@
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory) {
+ return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
+}
+
+MemRefDescriptor MemRefDescriptor::fromStaticShape(
+ OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ MemRefType type, Value memory, Value alignedMemory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
// Extract all strides and offsets and verify they are static.
@@ -61,7 +67,7 @@
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
- descr.setAlignedPtr(builder, loc, memory);
+ descr.setAlignedPtr(builder, loc, alignedMemory);
descr.setConstantOffset(builder, loc, offset);
// Fill in sizes and strides
Index: mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
===================================================================
--- mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -43,6 +43,10 @@
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory);
+ static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ MemRefType type, Value memory,
+ Value alignedMemory);
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D137364.473049.patch
Type: text/x-patch
Size: 4951 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221103/4c83b4b9/attachment.bin>
More information about the llvm-commits
mailing list