[Mlir-commits] [mlir] 787bf5e - [mlir] Add canonicalization for the `subtensor` op

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 2 03:06:46 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-02T06:05:52-04:00
New Revision: 787bf5e383a32b3ebc87332ff9e868db8f937056

URL: https://github.com/llvm/llvm-project/commit/787bf5e383a32b3ebc87332ff9e868db8f937056
DIFF: https://github.com/llvm/llvm-project/commit/787bf5e383a32b3ebc87332ff9e868db8f937056.diff

LOG: [mlir] Add canonicalization for the `subtensor` op

Differential revision: https://reviews.llvm.org/D88656

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dbc3e4ca521b..3d9daee964b6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3164,7 +3164,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
                                 ArrayRef<int64_t> staticStrides);
   }];
 
-  // let hasCanonicalizer = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d684a4b98e55..5548274eee18 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2989,50 +2989,59 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
   }
 }
 
+static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
+                             SubViewOp newOp) {
+  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, newOp, op.getType());
+} // Replace `op` by a memref_cast of `newOp` back to `op`'s type.
+
+static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
+                             SubTensorOp newOp) {
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
+} // Replace `op` by a tensor_cast of `newOp` back to `op`'s type.
+
 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpConstantArgumentFolder final
-    : public OpRewritePattern<SubViewOp> {
+template <typename OpType>
+class OpWithOffsetSizesAndStridesConstantArgumentFolder final
+    : public OpRewritePattern<OpType> {
 public:
-  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+  using OpRewritePattern<OpType>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+  LogicalResult matchAndRewrite(OpType op,
                                 PatternRewriter &rewriter) const override {
     // No constant operand, just return;
-    if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
+    if (llvm::none_of(op.getOperands(), [](Value operand) {
           return matchPattern(operand, m_ConstantIndex());
         }))
       return failure();

     // At least one of offsets/sizes/strides is a new constant.
     // Form the new list of operands and constant attributes from the existing.
-    SmallVector<Value, 8> newOffsets(subViewOp.offsets());
+    SmallVector<Value, 8> newOffsets(op.offsets());
     SmallVector<int64_t, 8> newStaticOffsets =
-        extractFromI64ArrayAttr(subViewOp.static_offsets());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_offsets());
+    assert(newStaticOffsets.size() == op.getSourceRank());
     canonicalizeSubViewPart(newOffsets, newStaticOffsets,
                             ShapedType::isDynamicStrideOrOffset);

-    SmallVector<Value, 8> newSizes(subViewOp.sizes());
+    SmallVector<Value, 8> newSizes(op.sizes());
     SmallVector<int64_t, 8> newStaticSizes =
-        extractFromI64ArrayAttr(subViewOp.static_sizes());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_sizes());
+    assert(newStaticSizes.size() == op.getSourceRank());
     canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);

-    SmallVector<Value, 8> newStrides(subViewOp.strides());
+    SmallVector<Value, 8> newStrides(op.strides());
     SmallVector<int64_t, 8> newStaticStrides =
-        extractFromI64ArrayAttr(subViewOp.static_strides());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_strides());
+    assert(newStaticStrides.size() == op.getSourceRank());
     canonicalizeSubViewPart(newStrides, newStaticStrides,
                             ShapedType::isDynamicStrideOrOffset);

     // Create the new op in canonical form.
-    auto newSubViewOp = rewriter.create<SubViewOp>(
-        subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
-        newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
+    auto newOp = rewriter.create<OpType>(
+        op.getLoc(), op.source(), newStaticOffsets, newStaticSizes,
+        newStaticStrides, newOffsets, newSizes, newStrides);

-    // Insert a memref_cast for compatibility of the uses of the op.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
-                                              subViewOp.getType());
+    replaceWithNewOp(rewriter, op, newOp); // Casts back to op.getType().

     return success();
   }
@@ -3183,8 +3192,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
-      context);
+  results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
+                 SubViewOpMemRefCastFolder>(context); // SubViewOp-only pattern.
 }
 
 //===----------------------------------------------------------------------===//
@@ -3275,6 +3284,13 @@ static LogicalResult verify(SubTensorOp op) {
   return success();
 }
 
+void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  using ConstantArgumentFolder =
+      OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>;
+  results.insert<ConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3603c473a1fd..dc7be097b0c0 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1110,3 +1110,32 @@ func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
   // CHECK-NEXT: return %[[C2]]
   return %1 : tensor<8x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @subtensor
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+  -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c11 = constant 11 : index
+
+  // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
+  // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+  // Test: constant operands still fold when one size (%arg0) is dynamic.
+  // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
+  // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+  // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
+  %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
+    : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+
+  return %2 : tensor<?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list