[Mlir-commits] [mlir] b3c227a - [mlir] Better support for rank-reducing subview / subtensor type inference.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Feb 19 00:35:10 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-19T08:30:50Z
New Revision: b3c227a25a231248e3752918c2cac1a7b9414ef1
URL: https://github.com/llvm/llvm-project/commit/b3c227a25a231248e3752918c2cac1a7b9414ef1
DIFF: https://github.com/llvm/llvm-project/commit/b3c227a25a231248e3752918c2cac1a7b9414ef1.diff
LOG: [mlir] Better support for rank-reducing subview / subtensor type inference.
Differential Revision: https://reviews.llvm.org/D96995
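The commit adds `inferRankReducedResultType` overloads to SubViewOp and SubTensorOp so a result type can be inferred directly at a reduced rank, instead of callers (such as the memref_cast folder below) projecting away unit dimensions by hand. A minimal C++ sketch of how the SubViewOp overload could be exercised follows; the context, builder, example function name and concrete shapes are illustrative assumptions, not part of the commit:

```c++
// Sketch only (not part of this commit): compares full-rank and rank-reduced
// type inference for a subview of a static 4-D memref.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void inferRankReducedSubViewTypeExample() {
  MLIRContext context;
  Builder b(&context);
  // Source type: memref<4x6x16x32xi8>.
  auto sourceType = MemRefType::get({4, 6, 16, 32}, b.getIntegerType(8));
  SmallVector<int64_t> offsets{0, 1, 0, 0};
  SmallVector<int64_t> sizes{1, 1, 16, 32};
  SmallVector<int64_t> strides{1, 1, 1, 1};
  // Full-rank inference keeps the two unit dimensions:
  //   memref<1x1x16x32xi8, #map>.
  Type fullRankType =
      SubViewOp::inferResultType(sourceType, offsets, sizes, strides);
  // Rank-reduced inference additionally projects away unit dimensions (and the
  // corresponding dimensions of the layout map) until the requested rank is
  // reached: memref<16x32xi8, #map'>.
  Type reducedType = SubViewOp::inferRankReducedResultType(
      /*resultRank=*/2, sourceType, offsets, sizes, strides);
  (void)fullRankType;
  (void)reducedType;
}
```

The SubTensorOp overloads behave the same way on RankedTensorType, minus the layout-map projection.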
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 64279c8fce3c..82b4717b6bc1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2788,6 +2788,16 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
+ static Type inferRankReducedResultType(unsigned resultRank,
+ MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static Type inferRankReducedResultType(unsigned resultRank,
+ MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
@@ -2914,6 +2924,16 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
+ static Type inferRankReducedResultType(unsigned resultRank,
+ RankedTensorType sourceRankedTensorType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static Type inferRankReducedResultType(unsigned resultRank,
+ RankedTensorType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
@@ -3027,6 +3047,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
}];
+
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index e916de4c0658..084d3fdfb2bf 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2891,8 +2891,68 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
staticStrides, ShapedType::kDynamicStrideOrOffset);
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
- staticSizes, staticStrides)
- .cast<MemRefType>();
+ staticSizes, staticStrides);
+}
+
+static void
+getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
+ llvm::SmallDenseSet<unsigned> &dimsToProject) {
+ dimsToProject.reserve(rank);
+ for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
+ if (shape[pos] == 1) {
+ dimsToProject.insert(pos);
+ --rank;
+ }
+ }
+}
+
+Type SubViewOp::inferRankReducedResultType(
+ unsigned resultRank, MemRefType sourceRankedTensorType,
+ ArrayRef<int64_t> leadingStaticOffsets,
+ ArrayRef<int64_t> leadingStaticSizes,
+ ArrayRef<int64_t> leadingStaticStrides) {
+ auto inferredType =
+ inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+ leadingStaticSizes, leadingStaticStrides)
+ .cast<MemRefType>();
+ assert(inferredType.getRank() >= resultRank && "expected result rank to be no larger than the inferred rank");
+ int rankDiff = inferredType.getRank() - resultRank;
+ if (rankDiff > 0) {
+ auto shape = inferredType.getShape();
+ llvm::SmallDenseSet<unsigned> dimsToProject;
+ getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+ SmallVector<int64_t> projectedShape;
+ for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+ if (!dimsToProject.contains(pos))
+ projectedShape.push_back(shape[pos]);
+
+ AffineMap map;
+ auto maps = inferredType.getAffineMaps();
+ if (!maps.empty() && maps.front())
+ map = getProjectedMap(maps.front(), dimsToProject);
+ inferredType =
+ MemRefType::get(projectedShape, inferredType.getElementType(), map,
+ inferredType.getMemorySpace());
+ }
+ return inferredType;
+}
+
+Type SubViewOp::inferRankReducedResultType(
+ unsigned resultRank, MemRefType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> leadingStaticOffsets,
+ ArrayRef<OpFoldResult> leadingStaticSizes,
+ ArrayRef<OpFoldResult> leadingStaticStrides) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+ staticOffsets, ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+ staticStrides, ShapedType::kDynamicStrideOrOffset);
+ return SubViewOp::inferRankReducedResultType(
+ resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+ staticStrides);
}
// Build a SubViewOp with mixed static and dynamic entries and custom result
@@ -3407,29 +3467,11 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
/// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
- auto resultType = SubViewOp::inferResultType(
- castOp.source().getType().cast<MemRefType>(),
- extractFromI64ArrayAttr(subViewOp.static_offsets()),
- extractFromI64ArrayAttr(subViewOp.static_sizes()),
- extractFromI64ArrayAttr(subViewOp.static_strides()))
- .cast<MemRefType>();
- uint32_t rankDiff =
- subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
- if (rankDiff > 0) {
- auto shape = resultType.getShape();
- auto projectedShape = shape.drop_front(rankDiff);
- AffineMap map;
- auto maps = resultType.getAffineMaps();
- if (!maps.empty() && maps.front()) {
- auto optionalUnusedDimsMask =
- computeRankReductionMask(shape, projectedShape);
- llvm::SmallDenseSet<unsigned> dimsToProject =
- optionalUnusedDimsMask.getValue();
- map = getProjectedMap(maps.front(), dimsToProject);
- }
- resultType = MemRefType::get(projectedShape, resultType.getElementType(),
- map, resultType.getMemorySpace());
- }
+ auto resultType = SubViewOp::inferRankReducedResultType(
+ subViewOp.getType().getRank(),
+ castOp.source().getType().cast<MemRefType>(),
+ subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+ subViewOp.getMixedStrides());
Value newSubView = rewriter.create<SubViewOp>(
subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
@@ -3492,8 +3534,52 @@ Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
staticStrides, ShapedType::kDynamicStrideOrOffset);
return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
- staticSizes, staticStrides)
- .cast<RankedTensorType>();
+ staticSizes, staticStrides);
+}
+
+/// A subtensor result type can be fully inferred from the source type and the
+/// static representation of offsets, sizes and strides. Special sentinels
+/// encode the dynamic case.
+Type SubTensorOp::inferRankReducedResultType(
+ unsigned resultRank, RankedTensorType sourceRankedTensorType,
+ ArrayRef<int64_t> leadingStaticOffsets,
+ ArrayRef<int64_t> leadingStaticSizes,
+ ArrayRef<int64_t> leadingStaticStrides) {
+ auto inferredType =
+ inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+ leadingStaticSizes, leadingStaticStrides)
+ .cast<RankedTensorType>();
+ int rankDiff = inferredType.getRank() - resultRank;
+ if (rankDiff > 0) {
+ auto shape = inferredType.getShape();
+ llvm::SmallDenseSet<unsigned> dimsToProject;
+ getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+ SmallVector<int64_t> projectedShape;
+ for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+ if (!dimsToProject.contains(pos))
+ projectedShape.push_back(shape[pos]);
+ inferredType =
+ RankedTensorType::get(projectedShape, inferredType.getElementType());
+ }
+ return inferredType;
+}
+
+Type SubTensorOp::inferRankReducedResultType(
+ unsigned resultRank, RankedTensorType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> leadingStaticOffsets,
+ ArrayRef<OpFoldResult> leadingStaticSizes,
+ ArrayRef<OpFoldResult> leadingStaticStrides) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+ staticOffsets, ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+ staticStrides, ShapedType::kDynamicStrideOrOffset);
+ return SubTensorOp::inferRankReducedResultType(
+ resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+ staticStrides);
}
// Build a SubTensorOp with mixed static and dynamic entries and custom result
@@ -3571,11 +3657,65 @@ static LogicalResult verify(SubTensorOp op) {
return produceSubViewErrorMsg(result, op, expectedType);
}
+namespace {
+/// Pattern to rewrite a subtensor op with tensor::Cast arguments.
+/// This essentially pushes tensor.cast past its consuming subtensor when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
+/// %1 = subtensor %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to tensor<3x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %0 = subtensor %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
+/// %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
+/// ```
+class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
+public:
+ using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+ PatternRewriter &rewriter) const override {
+ // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+ if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return failure();
+
+ auto castOp = subTensorOp.source().getDefiningOp<tensor::CastOp>();
+ if (!castOp)
+ return failure();
+
+ if (!canFoldIntoConsumerOp(castOp))
+ return failure();
+
+ /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
+ /// on the cast source operand type and the SubTensorOp static information.
+ /// This is the resulting type if the tensor::CastOp were folded and
+ /// rank-reduced to the desired result rank.
+ auto resultType = SubTensorOp::inferRankReducedResultType(
+ subTensorOp.getType().getRank(),
+ castOp.source().getType().cast<RankedTensorType>(),
+ subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+ subTensorOp.getMixedStrides());
+ Value newSubTensor = rewriter.create<SubTensorOp>(
+ subTensorOp.getLoc(), resultType, castOp.source(),
+ subTensorOp.offsets(), subTensorOp.sizes(), subTensorOp.strides(),
+ subTensorOp.static_offsets(), subTensorOp.static_sizes(),
+ subTensorOp.static_strides());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ subTensorOp, subTensorOp.getType(), newSubTensor);
+ return success();
+ }
+};
+} // namespace
+
void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results
- .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
- context);
+ results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>,
+ SubTensorOpCastFolder>(context);
}
//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 9247152e8677..5c437ae3dda4 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -146,13 +146,13 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
// CHECK-LABEL: func @subview_of_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
// CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
%0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
- %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+ %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
memref<?x?x16x32xi8> to
memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
@@ -176,3 +176,14 @@ func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x
return %0 : tensor<4x6x16x32xi8>
}
+// CHECK-LABEL: func @rank_reducing_tensor_of_cast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
+// Tensor cast is moved after subtensor and then gets canonicalized away.
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[S]] : tensor<16x32xi8>
+func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
+ %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+ %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
+ return %1 : tensor<16x32xi8>
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 62c07dd8a063..3bc3eeee8354 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1034,8 +1034,8 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
// CHECK-LABEL: func @subtensor
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
-func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
- -> tensor<?x?x?xf32>
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+ -> tensor<?x?x?xf32>
{
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -1045,16 +1045,18 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
// CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
// CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
- // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+ // tensor.cast gets folded away in consumer.
+ // CHECK-NOT: tensor.cast
%1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>
// Test: subtensor with one dynamic operand can also be folded.
// CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
- // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+ // CHECK-SAME: tensor<7x11x2xf32> to tensor<2x?x2xf32>
// CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
%2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
return %2 : tensor<?x?x?xf32>
}
+