[Mlir-commits] [mlir] a9e68db - [mlir] Add canonicalizations for subtensor_insert operation.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 1 14:59:36 PST 2021


Author: MaheshRavishankar
Date: 2021-03-01T14:59:18-08:00
New Revision: a9e68db9736080373d73606d89a270e38d7f1273

URL: https://github.com/llvm/llvm-project/commit/a9e68db9736080373d73606d89a270e38d7f1273
DIFF: https://github.com/llvm/llvm-project/commit/a9e68db9736080373d73606d89a270e38d7f1273.diff

LOG: [mlir] Add canonicalizations for subtensor_insert operation.

Add canonicalizers to subtensor_insert operations that propagate the
constant arguments within offsets, sizes and strides. Also add a
pattern to propagate tensor_cast operations.

Differential Revision: https://reviews.llvm.org/D97704

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2bcae1cb8f04..fe054c59ae6e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3048,6 +3048,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 501b3d8e2b18..539252af5cf9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3795,6 +3795,95 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
   return OpFoldResult();
 }
 
+namespace {
+/// Rewrites a subtensor_insert op that has constant-index operands.
+class SubTensorInsertOpConstantArgumentFolder final
+    : public OpRewritePattern<SubTensorInsertOp> {
+public:
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless at least one operand is defined by a constant index.
+    if (llvm::none_of(subTensorInsertOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    // Copy the mixed offsets/sizes/strides; canonicalizeSubViewPart is then
+    // expected to turn constant operands into attributes (helper not shown).
+    SmallVector<OpFoldResult> mixedOffsets(subTensorInsertOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(subTensorInsertOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(subTensorInsertOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Rebuild the op; cast the source if its type differs from the new sizes.
+    Value source = subTensorInsertOp.source();
+    RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
+    SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
+        llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
+          if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+            return attr.cast<IntegerAttr>().getInt();
+          return ShapedType::kDynamicSize; // Not an attribute: dynamic.
+        }));
+    RankedTensorType newSourceType =
+        RankedTensorType::get(shape, sourceType.getElementType());
+    Location loc = subTensorInsertOp.getLoc();
+    if (sourceType != newSourceType)
+      source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
+        mixedSizes, mixedStrides);
+    return success();
+  }
+};
+
+/// Folds tensor.cast producers of source/dest into subtensor_insert ops.
+struct SubTensorInsertOpCastFolder final
+    : public OpRewritePattern<SubTensorInsertOp> {
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure(); // Skip ops that still have constant-index operands.
+
+    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
+      auto castOp = v.getDefiningOp<tensor::CastOp>();
+      if (!castOp || !canFoldIntoConsumerOp(castOp))
+        return llvm::None; // Not produced by a foldable tensor.cast.
+      return castOp.source();
+    };
+    Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
+    Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
+    if (!sourceCastSource && !destCastSource &&
+        subTensorOp.dest().getType() == subTensorOp.getResult().getType())
+      return failure(); // No cast to fold and the types already agree.
+
+    auto newOp = rewriter.create<SubTensorInsertOp>(
+        subTensorOp.getLoc(),
+        (sourceCastSource ? *sourceCastSource : subTensorOp.source()),
+        (destCastSource ? *destCastSource : subTensorOp.dest()),
+        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+        subTensorOp.getMixedStrides());
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
+                                                subTensorOp.getType(), newOp);
+    return success();
+  }
+};
+} // namespace
+
+void SubTensorInsertOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SubTensorInsertOpConstantArgumentFolder,
+                 SubTensorInsertOpCastFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index b887e90e931b..72b886a238ff 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -252,3 +252,51 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<
   %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
+
+// -----
+
+func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c8 = constant 8 : index
+  %0 = dim %arg0, %c1 : tensor<2x?xi32>
+  %1 = tensor.extract %arg1[] : tensor<i32>
+  %2 = tensor.generate %arg2, %c8 {
+  ^bb0(%arg4: index, %arg5: index):
+    tensor.yield %1 : i32
+  } : tensor<?x?xi32>
+  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_canonicalize
+//       CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+//  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
+//       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+//       CHECK:   return %[[CAST]]
+
+// -----
+
+func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c9 = constant 9 : index
+  %c3 = constant 3 : index
+  %2 = tensor.extract %arg1[] : tensor<i32>
+  %4 = tensor.generate %c3, %c9 {
+  ^bb0(%arg2: index, %arg3: index):
+    tensor.yield %2 : i32
+  } : tensor<?x?xi32>
+  %5 = subtensor_insert %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
+  %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
+  return %6 : tensor<3x9xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_output_dest_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-z0-9_]+]]: tensor<2x3xi32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
+//       CHECK:   %[[PAD:.+]] = tensor.extract %[[ARG1]]
+//       CHECK:   %[[GENERATE:.+]] = tensor.generate
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
+//       CHECK:   return %[[RESULT]]


        


More information about the Mlir-commits mailing list