[Mlir-commits] [mlir] 787bf5e - [mlir] Add canonicalization for the `subtensor` op
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 2 03:06:46 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-02T06:05:52-04:00
New Revision: 787bf5e383a32b3ebc87332ff9e868db8f937056
URL: https://github.com/llvm/llvm-project/commit/787bf5e383a32b3ebc87332ff9e868db8f937056
DIFF: https://github.com/llvm/llvm-project/commit/787bf5e383a32b3ebc87332ff9e868db8f937056.diff
LOG: [mlir] Add canonicalization for the `subtensor` op
Differential revision: https://reviews.llvm.org/D88656
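
The effect mirrors the existing `subview` canonicalization: when the offsets, sizes, or strides of a `subtensor` are constant index operands, they are folded into the static attribute form and a `tensor_cast` recovers the original result type. A minimal sketch of the rewrite, with values taken from the test added below (SSA names are illustrative, and %c0/%c1/%c2/%c7/%c11 are index constants as defined in that test; this is not an exact printout of the canonicalizer output):

  %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
      : tensor<8x16x4xf32> to tensor<?x?x?xf32>

  // becomes, roughly:
  %static = subtensor %t[0, 0, 0] [7, 11, 2] [1, 1, 1]
      : tensor<8x16x4xf32> to tensor<7x11x2xf32>
  %1 = tensor_cast %static : tensor<7x11x2xf32> to tensor<?x?x?xf32>

The same templated pattern (OpWithOffsetSizesAndStridesConstantArgumentFolder) now drives both the subview and subtensor cases; the only difference is whether the result is reconciled through memref_cast or tensor_cast, selected by the replaceWithNewOp overloads in the diff below.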
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dbc3e4ca521b..3d9daee964b6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3164,7 +3164,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
ArrayRef<int64_t> staticStrides);
}];
- // let hasCanonicalizer = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d684a4b98e55..5548274eee18 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2989,50 +2989,59 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
}
}
+static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
+ SubViewOp newOp) {
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(op, newOp, op.getType());
+}
+
+static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
+ SubTensorOp newOp) {
+ rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
+}
+
/// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpConstantArgumentFolder final
- : public OpRewritePattern<SubViewOp> {
+template <typename OpType>
+class OpWithOffsetSizesAndStridesConstantArgumentFolder final
+ : public OpRewritePattern<OpType> {
public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
+ using OpRewritePattern<OpType>::OpRewritePattern;
- LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
// No constant operand, just return;
- if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
+ if (llvm::none_of(op.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return failure();
// At least one of offsets/sizes/strides is a new constant.
// Form the new list of operands and constant attributes from the existing.
- SmallVector<Value, 8> newOffsets(subViewOp.offsets());
+ SmallVector<Value, 8> newOffsets(op.offsets());
SmallVector<int64_t, 8> newStaticOffsets =
- extractFromI64ArrayAttr(subViewOp.static_offsets());
- assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+ extractFromI64ArrayAttr(op.static_offsets());
+ assert(newStaticOffsets.size() == op.getSourceRank());
canonicalizeSubViewPart(newOffsets, newStaticOffsets,
ShapedType::isDynamicStrideOrOffset);
- SmallVector<Value, 8> newSizes(subViewOp.sizes());
+ SmallVector<Value, 8> newSizes(op.sizes());
SmallVector<int64_t, 8> newStaticSizes =
- extractFromI64ArrayAttr(subViewOp.static_sizes());
- assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+ extractFromI64ArrayAttr(op.static_sizes());
+ assert(newStaticOffsets.size() == op.getSourceRank());
canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
- SmallVector<Value, 8> newStrides(subViewOp.strides());
+ SmallVector<Value, 8> newStrides(op.strides());
SmallVector<int64_t, 8> newStaticStrides =
- extractFromI64ArrayAttr(subViewOp.static_strides());
- assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+ extractFromI64ArrayAttr(op.static_strides());
+ assert(newStaticOffsets.size() == op.getSourceRank());
canonicalizeSubViewPart(newStrides, newStaticStrides,
ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
- newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
+ auto newOp = rewriter.create<OpType>(
+ op.getLoc(), op.source(), newStaticOffsets, newStaticSizes,
+ newStaticStrides, newOffsets, newSizes, newStrides);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
+ replaceWithNewOp(rewriter, op, newOp);
return success();
}
@@ -3183,8 +3192,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
- context);
+ results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
+ SubViewOpMemRefCastFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -3275,6 +3284,13 @@ static LogicalResult verify(SubTensorOp op) {
return success();
}
+void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results
+ .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
+ context);
+}
+
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3603c473a1fd..dc7be097b0c0 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1110,3 +1110,32 @@ func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<8x4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @subtensor
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+ -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c7 = constant 7 : index
+ %c11 = constant 11 : index
+
+ // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
+ // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
+ // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+ %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
+ : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+ // Test: subtensor with one dynamic operand can also be folded.
+ // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
+ // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+ // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
+ %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+
+ return %2 : tensor<?x?x?xf32>
+}