[Mlir-commits] [mlir] 5380e30 - [mlir] translate memref.reshape ops that have static shapes
Ashay Rane
llvmlistbot at llvm.org
Thu May 12 11:57:41 PDT 2022
Author: Ashay Rane
Date: 2022-05-12T11:57:20-07:00
New Revision: 5380e30e047bbac9b2cceb69162eb8db1e1a7abf
URL: https://github.com/llvm/llvm-project/commit/5380e30e047bbac9b2cceb69162eb8db1e1a7abf
DIFF: https://github.com/llvm/llvm-project/commit/5380e30e047bbac9b2cceb69162eb8db1e1a7abf.diff
LOG: [mlir] translate memref.reshape ops that have static shapes
This patch adapts the code for translating memref.reinterpret_cast ops
to add a translation rule for memref.reshape ops whose shape argument
has a static type. Since reshape ops, unlike reinterpret_cast ops, carry
no explicit offset, size, or stride operands, the lowering simply copies
the allocated and aligned pointers into the new MemRef descriptor and
fills in the offset, sizes, and strides from the statically-known result
type.
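As a minimal illustration (mirroring the test added below; the @shape
global and @example function names are only for exposition), an op of
this form now lowers directly in MemRefToLLVM rather than relying on the
reinterpret_cast path:

  memref.global "private" constant @shape : memref<3xi64> = dense<[2, 6, 20]>

  func.func @example(%src: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
    %shape = memref.get_global @shape : memref<3xi64>
    // The shape operand has a static type (memref<3xi64>), so the result
    // sizes and strides are known at compile time.
    %dst = memref.reshape %src(%shape)
        : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
    return %dst : memref<2x6x20xf32>
  }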
Reviewed By: ftynse, cathyzhyi
Differential Revision: https://reviews.llvm.org/D125039
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 110b40adf777f..579f2da92fa41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -26,6 +26,10 @@ using namespace mlir;
namespace {
+bool isStaticStrideOrOffset(int64_t strideOrOffset) {
+ return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
+}
+
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
@@ -1091,11 +1095,52 @@ struct MemRefReshapeOpLowering
Type srcType, memref::ReshapeOp reshapeOp,
memref::ReshapeOp::Adaptor adaptor,
Value *descriptor) const {
- // Conversion for statically-known shape args is performed via
- // `memref_reinterpret_cast`.
auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
- if (shapeMemRefType.hasStaticShape())
- return failure();
+ if (shapeMemRefType.hasStaticShape()) {
+ MemRefType targetMemRefType =
+ reshapeOp.getResult().getType().cast<MemRefType>();
+ auto llvmTargetDescriptorTy =
+ typeConverter->convertType(targetMemRefType)
+ .dyn_cast_or_null<LLVM::LLVMStructType>();
+ if (!llvmTargetDescriptorTy)
+ return failure();
+
+ // Create descriptor.
+ Location loc = reshapeOp.getLoc();
+ auto desc =
+ MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+ // Set allocated and aligned pointers.
+ Value allocatedPtr, alignedPtr;
+ extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+ reshapeOp.source(), adaptor.source(),
+ &allocatedPtr, &alignedPtr);
+ desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+ desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+ // Extract the offset and strides from the type.
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+ return rewriter.notifyMatchFailure(
+ reshapeOp, "failed to get stride and offset exprs");
+
+ if (!isStaticStrideOrOffset(offset))
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "dynamic offset is unsupported");
+ if (!llvm::all_of(strides, isStaticStrideOrOffset))
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "dynamic strides are unsupported");
+
+ desc.setConstantOffset(rewriter, loc, offset);
+ for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+ desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
+ desc.setConstantStride(rewriter, loc, i, strides[i]);
+ }
+
+ *descriptor = desc;
+ return success();
+ }
// The shape is a rank-1 tensor with unknown length.
Location loc = reshapeOp.getLoc();
@@ -1499,10 +1544,7 @@ class ReassociatingReshapeOpConversion
for (auto &en : llvm::enumerate(dstShape))
dstDesc.setSize(rewriter, loc, en.index(), en.value());
- auto isStaticStride = [](int64_t stride) {
- return !ShapedType::isDynamicStrideOrOffset(stride);
- };
- if (llvm::all_of(strides, isStaticStride)) {
+ if (llvm::all_of(strides, isStaticStrideOrOffset)) {
for (auto &en : llvm::enumerate(strides))
dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
} else if (srcType.getLayout().isIdentity() &&
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 93f1449002902..b8f3717682a05 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -217,3 +217,38 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
}
}
+// -----
+
+memref.global "private" constant @__constant_3xi64 : memref<3xi64> = dense<[2, 6, 20]>
+
+// CHECK-LABEL: func @memref.reshape
+// CHECK-SAME: %[[arg0:.*]]: memref<4x5x6xf32>) -> memref<2x6x20xf32>
+func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
+ // CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<4x5x6xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ %0 = memref.get_global @__constant_3xi64 : memref<3xi64>
+
+ // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[elem0:.*]] = llvm.extractvalue %[[cast0]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
+ // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
+ // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
+ // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
+ // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
+ // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
+ %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
+
+ // CHECK: return %[[cast1]] : memref<2x6x20xf32>
+ return %1 : memref<2x6x20xf32>
+}
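For reference, the new pattern can be exercised through the MemRefToLLVM
conversion pass; assuming the RUN line of convert-static-memref-ops.mlir
at this revision, the test is driven by an invocation of roughly this
form:

  // RUN: mlir-opt -convert-memref-to-llvm -split-input-file %s | FileCheck %s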