[PATCH] D137364: [mlir][MemRefToLLVM] Fix the lowering of extract_strided_metadata
Quentin Colombet via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 4 11:48:30 PDT 2022
qcolombet updated this revision to Diff 473303.
qcolombet added a comment.
- The order of evaluation of a function's arguments is platform dependent. As a result, the code that extracts the base and aligned pointers from the source memref could be emitted in a different order on different platforms, hence the failure on Debian. To fix that, put the `basePtr` and `alignedPtr` in temporary variables before making the call.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D137364/new/
https://reviews.llvm.org/D137364
Files:
mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Index: mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1169,6 +1169,12 @@
// CHECK-SAME: %[[ARG:.*]]: memref
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
Index: mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2115,7 +2115,7 @@
return failure();
// Create the descriptor.
- MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
+ MemRefDescriptor sourceMemRef(adaptor.getSource());
Location loc = extractStridedMetadataOp.getLoc();
Value source = extractStridedMetadataOp.getSource();
@@ -2125,7 +2125,13 @@
results.reserve(2 + rank * 2);
// Base buffer.
- results.push_back(sourceMemRef.allocatedPtr(rewriter, loc));
+ Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
+ MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(),
+ extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+ baseBuffer, alignedBuffer);
+ results.push_back((Value)dstMemRef);
// Offset.
results.push_back(sourceMemRef.offset(rewriter, loc));
Index: mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
===================================================================
--- mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -43,6 +43,12 @@
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory) {
+ return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
+}
+
+MemRefDescriptor MemRefDescriptor::fromStaticShape(
+ OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ MemRefType type, Value memory, Value alignedMemory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
// Extract all strides and offsets and verify they are static.
@@ -61,7 +67,7 @@
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
- descr.setAlignedPtr(builder, loc, memory);
+ descr.setAlignedPtr(builder, loc, alignedMemory);
descr.setConstantOffset(builder, loc, offset);
// Fill in sizes and strides
Index: mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
===================================================================
--- mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -43,6 +43,10 @@
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory);
+ static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ MemRefType type, Value memory,
+ Value alignedMemory);
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D137364.473303.patch
Type: text/x-patch
Size: 4971 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221104/a94a0ef5/attachment.bin>
More information about the llvm-commits
mailing list