[PATCH] D137364: [mlir][MemRefToLLVM] Fix the lowering of extract_strided_metadata

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 4 11:48:30 PDT 2022


qcolombet updated this revision to Diff 473303.
qcolombet added a comment.

- The order in which a function call's arguments are evaluated is unspecified in C++ and differs between compilers/platforms. As a result, the code that extracts the base and aligned pointers from the source memref was emitted in a different order on different platforms, which caused the test failure on Debian. To fix this, the base and aligned pointers (`baseBuffer` and `alignedBuffer`) are now stored in temporary variables before the call is made, so the IR is emitted in a deterministic order.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D137364/new/

https://reviews.llvm.org/D137364

Files:
  mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
  mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
  mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir


Index: mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1169,6 +1169,12 @@
 // CHECK-SAME: %[[ARG:.*]]: memref
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
Index: mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2115,7 +2115,7 @@
       return failure();
 
     // Create the descriptor.
-    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
+    MemRefDescriptor sourceMemRef(adaptor.getSource());
     Location loc = extractStridedMetadataOp.getLoc();
     Value source = extractStridedMetadataOp.getSource();
 
@@ -2125,7 +2125,13 @@
     results.reserve(2 + rank * 2);
 
     // Base buffer.
-    results.push_back(sourceMemRef.allocatedPtr(rewriter, loc));
+    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
+    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
+    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
+        rewriter, loc, *getTypeConverter(),
+        extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+        baseBuffer, alignedBuffer);
+    results.push_back((Value)dstMemRef);
 
     // Offset.
     results.push_back(sourceMemRef.offset(rewriter, loc));
Index: mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
===================================================================
--- mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -43,6 +43,12 @@
 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                   LLVMTypeConverter &typeConverter,
                                   MemRefType type, Value memory) {
+  return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
+}
+
+MemRefDescriptor MemRefDescriptor::fromStaticShape(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    MemRefType type, Value memory, Value alignedMemory) {
   assert(type.hasStaticShape() && "unexpected dynamic shape");
 
   // Extract all strides and offsets and verify they are static.
@@ -61,7 +67,7 @@
 
   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
   descr.setAllocatedPtr(builder, loc, memory);
-  descr.setAlignedPtr(builder, loc, memory);
+  descr.setAlignedPtr(builder, loc, alignedMemory);
   descr.setConstantOffset(builder, loc, offset);
 
   // Fill in sizes and strides
Index: mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
===================================================================
--- mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -43,6 +43,10 @@
   static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
                                           LLVMTypeConverter &typeConverter,
                                           MemRefType type, Value memory);
+  static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
+                                          LLVMTypeConverter &typeConverter,
+                                          MemRefType type, Value memory,
+                                          Value alignedMemory);
 
   /// Builds IR extracting the allocated pointer from the descriptor.
   Value allocatedPtr(OpBuilder &builder, Location loc);


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D137364.473303.patch
Type: text/x-patch
Size: 4971 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221104/a94a0ef5/attachment.bin>


More information about the llvm-commits mailing list