[Mlir-commits] [mlir] e131807 - Support non identity layout map for reshape ops in MemRefToLLVM lowering
Yi Zhang
llvmlistbot at llvm.org
Tue Apr 26 10:06:22 PDT 2022
Author: Yi Zhang
Date: 2022-04-26T13:03:53-04:00
New Revision: e1318078a4e160eb723bcbcfcdcc9a1b618f7067
URL: https://github.com/llvm/llvm-project/commit/e1318078a4e160eb723bcbcfcdcc9a1b618f7067
DIFF: https://github.com/llvm/llvm-project/commit/e1318078a4e160eb723bcbcfcdcc9a1b618f7067.diff
LOG: Support non identity layout map for reshape ops in MemRefToLLVM lowering
This change borrows the ideas from `computeExpanded/CollapsedLayoutMap`
and computes the dynamic strides at runtime for the memref descriptors.
Differential Revision: https://reviews.llvm.org/D124001
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index 28d4a0926007e..ac54ee6888136 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -79,6 +79,9 @@ class MemRefDescriptor : public StructBuilder {
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
+ /// Returns the index type used for sizes and strides in this descriptor.
+ Type getIndexType() { return indexType; };
+
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index b0c881077efb9..110b40adf777f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1301,7 +1301,7 @@ static OpFoldResult getCollapsedOutputDimSize(
static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassocation,
+ ArrayRef<ReassociationIndices> reassociation,
ArrayRef<int64_t> inStaticShape,
MemRefDescriptor &inDesc,
ArrayRef<int64_t> outStaticShape) {
@@ -1309,42 +1309,155 @@ getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
outStaticShape[outDimIndex],
- inStaticShape, inDesc, reassocation);
+ inStaticShape, inDesc, reassociation);
}));
}
static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassocation,
+ ArrayRef<ReassociationIndices> reassociation,
ArrayRef<int64_t> inStaticShape,
MemRefDescriptor &inDesc,
ArrayRef<int64_t> outStaticShape) {
DenseMap<int64_t, int64_t> outDimToInDimMap =
- getExpandedDimToCollapsedDimMap(reassocation);
+ getExpandedDimToCollapsedDimMap(reassociation);
return llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
outStaticShape, inDesc, inStaticShape,
- reassocation, outDimToInDimMap);
+ reassociation, outDimToInDimMap);
}));
}
static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassocation,
+ ArrayRef<ReassociationIndices> reassociation,
ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
ArrayRef<int64_t> outStaticShape) {
return outStaticShape.size() < inStaticShape.size()
? getAsValues(b, loc, llvmIndexType,
getCollapsedOutputShape(b, loc, llvmIndexType,
- reassocation, inStaticShape,
+ reassociation, inStaticShape,
inDesc, outStaticShape))
: getAsValues(b, loc, llvmIndexType,
getExpandedOutputShape(b, loc, llvmIndexType,
- reassocation, inStaticShape,
+ reassociation, inStaticShape,
inDesc, outStaticShape));
}
+static void fillInStridesForExpandedMemDescriptor(
+ OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
+ MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
+ // See comments for computeExpandedLayoutMap for details on how the strides
+ // are calculated.
+ for (auto &en : llvm::enumerate(reassociation)) {
+ auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
+ for (auto dstIndex : llvm::reverse(en.value())) {
+ dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
+ Value size = dstDesc.size(b, loc, dstIndex);
+ currentStrideToExpand =
+ b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
+ }
+ }
+}
+
+static void fillInStridesForCollapsedMemDescriptor(
+ ConversionPatternRewriter &rewriter, Location loc, Operation *op,
+ TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
+ MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
+ // See comments for computeCollapsedLayoutMap for details on how the strides
+ // are calculated.
+ auto srcShape = srcType.getShape();
+ for (auto &en : llvm::enumerate(reassociation)) {
+ rewriter.setInsertionPoint(op);
+ auto dstIndex = en.index();
+ ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
+ while (srcShape[ref.back()] == 1 && ref.size() > 1)
+ ref = ref.drop_back();
+ if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
+ dstDesc.setStride(rewriter, loc, dstIndex,
+ srcDesc.stride(rewriter, loc, ref.back()));
+ } else {
+ // Iterate over the source strides in reverse order. Skip over the
+ // dimensions whose size is 1.
+ // TODO: we should take the minimum stride in the reassociation group
+ // instead of just the first where the dimension is not 1.
+ //
+ // +------------------------------------------------------+
+ // | curEntry: |
+ // | %srcStride = strides[srcIndex] |
+ // | %neOne = cmp sizes[srcIndex],1 +--+
+ // | cf.cond_br %neOne, continue(%srcStride), nextEntry | |
+ // +-------------------------+----------------------------+ |
+ // | |
+ // v |
+ // +-----------------------------+ |
+ // | nextEntry: | |
+ // | ... +---+ |
+ // +--------------+--------------+ | |
+ // | | |
+ // v | |
+ // +-----------------------------+ | |
+ // | nextEntry: | | |
+ // | ... | | |
+ // +--------------+--------------+ | +--------+
+ // | | |
+ // v v v
+ // +--------------------------------------------------+
+ // | continue(%newStride): |
+ // | %newMemRefDes = setStride(%newStride,dstIndex) |
+ // +--------------------------------------------------+
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *initBlock = rewriter.getInsertionBlock();
+ Block *continueBlock =
+ rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
+ continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
+ rewriter.setInsertionPointToStart(continueBlock);
+ dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
+
+ Block *curEntryBlock = initBlock;
+ Block *nextEntryBlock;
+ for (auto srcIndex : llvm::reverse(ref)) {
+ if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
+ continue;
+ rewriter.setInsertionPointToEnd(curEntryBlock);
+ Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
+ if (srcIndex == ref.front()) {
+ rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
+ break;
+ }
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(rewriter.getI64Type()),
+ rewriter.getI32IntegerAttr(1));
+ Value predNeOne = rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
+ one);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ nextEntryBlock = rewriter.createBlock(
+ initBlock->getParent(), Region::iterator(continueBlock), {});
+ }
+ rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
+ srcStride, nextEntryBlock, llvm::None);
+ curEntryBlock = nextEntryBlock;
+ }
+ }
+ }
+}
+
+static void fillInDynamicStridesForMemDescriptor(
+ ConversionPatternRewriter &b, Location loc, Operation *op,
+ TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
+ MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
+ ArrayRef<ReassociationIndices> reassociation) {
+ if (srcType.getRank() > dstType.getRank())
+ fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
+ srcDesc, dstDesc, reassociation);
+ else
+ fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
+ reassociation);
+}
+
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
@@ -1361,15 +1474,6 @@ class ReassociatingReshapeOpConversion
MemRefType dstType = reshapeOp.getResultType();
MemRefType srcType = reshapeOp.getSrcType();
- // The condition on the layouts can be ignored when all shapes are static.
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
- if (!srcType.getLayout().isIdentity() ||
- !dstType.getLayout().isIdentity()) {
- return rewriter.notifyMatchFailure(
- reshapeOp, "only empty layout map is supported");
- }
- }
-
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(dstType, strides, offset))) {
@@ -1401,7 +1505,8 @@ class ReassociatingReshapeOpConversion
if (llvm::all_of(strides, isStaticStride)) {
for (auto &en : llvm::enumerate(strides))
dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
- } else {
+ } else if (srcType.getLayout().isIdentity() &&
+ dstType.getLayout().isIdentity()) {
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
rewriter.getIndexAttr(1));
Value stride = c1;
@@ -1410,6 +1515,12 @@ class ReassociatingReshapeOpConversion
dstDesc.setStride(rewriter, loc, dimIndex, stride);
stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
}
+ } else {
+ // There could be mixed static/dynamic strides. For simplicity, we
+ // recompute all strides if there is at least one dynamic stride.
+ fillInDynamicStridesForMemDescriptor(
+ rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
+ srcDesc, dstDesc, reshapeOp.getReassociationIndices());
}
rewriter.replaceOp(reshapeOp, {dstDesc});
return success();
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9b03c18a58802..8b215769e0751 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -706,6 +706,45 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
// -----
+func.func @collapse_shape_dynamic_with_non_identity_layout(
+ %arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
+ memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
+ %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
+ memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
+ memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
+ return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
+}
+// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.mlir.constant(4 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(1 : i32) : i64
+// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1
+// CHECK: ^bb1:
+// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.br ^bb2(%{{.*}} : i64)
+// CHECK: ^bb2(%[[STRIDE:.*]]: i64):
+// CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+
+// -----
+
func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
%0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
@@ -840,6 +879,44 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
// -----
+func.func @expand_shape_dynamic_with_non_identity_layout(
+ %arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
+ memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
+ memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
+ memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
+ return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
+}
+// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(2 : index) : i64
+// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.mlir.constant(2 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+
+// -----
+
// CHECK-LABEL: func @rank_of_unranked
// CHECK32-LABEL: func @rank_of_unranked
func.func @rank_of_unranked(%unranked: memref<*xi32>) {
More information about the Mlir-commits
mailing list