[Mlir-commits] [mlir] 200266a - [mlir][MemRef] Fix the lowering of extract_strided_metadata

Quentin Colombet llvmlistbot at llvm.org
Fri Nov 4 18:13:28 PDT 2022


Author: Quentin Colombet
Date: 2022-11-05T01:06:38Z
New Revision: 200266a0a12b0af77967c8770e1e1bdd634bfc4a

URL: https://github.com/llvm/llvm-project/commit/200266a0a12b0af77967c8770e1e1bdd634bfc4a
DIFF: https://github.com/llvm/llvm-project/commit/200266a0a12b0af77967c8770e1e1bdd634bfc4a.diff

LOG: [mlir][MemRef] Fix the lowering of extract_strided_metadata

The first result of the extract_strided_metadata operation is a MemRef,
not a naked pointer.
This patch fixes the lowering of this operation in MemRefToLLVM so that
we properly materialize the full MemRef structure and not just the naked
base pointer.

Differential Revision: https://reviews.llvm.org/D137364

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
    mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index ac54ee6888136..2b7735da84666 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -43,6 +43,10 @@ class MemRefDescriptor : public StructBuilder {
   static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
                                           LLVMTypeConverter &typeConverter,
                                           MemRefType type, Value memory);
+  static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
+                                          LLVMTypeConverter &typeConverter,
+                                          MemRefType type, Value memory,
+                                          Value alignedMemory);
 
   /// Builds IR extracting the allocated pointer from the descriptor.
   Value allocatedPtr(OpBuilder &builder, Location loc);

diff  --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index a65ac51c31c63..4f72cd1081f0e 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -43,6 +43,12 @@ MemRefDescriptor
 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                   LLVMTypeConverter &typeConverter,
                                   MemRefType type, Value memory) {
+  return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
+}
+
+MemRefDescriptor MemRefDescriptor::fromStaticShape(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    MemRefType type, Value memory, Value alignedMemory) {
   assert(type.hasStaticShape() && "unexpected dynamic shape");
 
   // Extract all strides and offsets and verify they are static.
@@ -61,7 +67,7 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
 
   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
   descr.setAllocatedPtr(builder, loc, memory);
-  descr.setAlignedPtr(builder, loc, memory);
+  descr.setAlignedPtr(builder, loc, alignedMemory);
   descr.setConstantOffset(builder, loc, offset);
 
   // Fill in sizes and strides

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d63b84ccdf856..4685590fa1d32 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2115,7 +2115,7 @@ class ExtractStridedMetadataOpLowering
       return failure();
 
     // Create the descriptor.
-    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
+    MemRefDescriptor sourceMemRef(adaptor.getSource());
     Location loc = extractStridedMetadataOp.getLoc();
     Value source = extractStridedMetadataOp.getSource();
 
@@ -2125,7 +2125,13 @@ class ExtractStridedMetadataOpLowering
     results.reserve(2 + rank * 2);
 
     // Base buffer.
-    results.push_back(sourceMemRef.allocatedPtr(rewriter, loc));
+    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
+    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
+    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
+        rewriter, loc, *getTypeConverter(),
+        extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+        baseBuffer, alignedBuffer);
+    results.push_back((Value)dstMemRef);
 
     // Offset.
     results.push_back(sourceMemRef.offset(rewriter, loc));

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 12fd3f21e3da0..344e06db09b62 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1169,6 +1169,12 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 // CHECK-SAME: %[[ARG:.*]]: memref
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>


        


More information about the Mlir-commits mailing list