[Mlir-commits] [mlir] a9e68db - [mlir] Add canonicalizations for subtensor_insert operation.
llvmlistbot at llvm.org
Mon Mar 1 14:59:36 PST 2021
Author: MaheshRavishankar
Date: 2021-03-01T14:59:18-08:00
New Revision: a9e68db9736080373d73606d89a270e38d7f1273
URL: https://github.com/llvm/llvm-project/commit/a9e68db9736080373d73606d89a270e38d7f1273
DIFF: https://github.com/llvm/llvm-project/commit/a9e68db9736080373d73606d89a270e38d7f1273.diff
LOG: [mlir] Add canonicalizations for subtensor_insert operation.
Add canonicalizers to subtensor_insert operations that propagate the
constant arguments within offsets, sizes and strides. Also add a
pattern that folds tensor_cast operations into subtensor_insert.
Differential Revision: https://reviews.llvm.org/D97704
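For illustration, a minimal before/after sketch of the constant-argument
folding (the values and shapes here are made up for this example, not taken
from the commit):

  // Before: offsets/sizes/strides are passed as constant SSA values.
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %r = subtensor_insert %src into %dst[%c0, %c0] [%c1, %c4] [%c1, %c1]
      : tensor<?x?xf32> into tensor<8x8xf32>

  // After canonicalization: the constants become static attributes and the
  // source is cast to the static shape implied by the sizes.
  %s = tensor.cast %src : tensor<?x?xf32> to tensor<1x4xf32>
  %r = subtensor_insert %s into %dst[0, 0] [1, 4] [1, 1]
      : tensor<1x4xf32> into tensor<8x8xf32>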
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2bcae1cb8f04..fe054c59ae6e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3048,6 +3048,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
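For context, setting hasCanonicalizer = 1 makes ODS declare the hook that the
Ops.cpp change below defines, roughly (a sketch of the generated declaration
as of this commit):

  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);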
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 501b3d8e2b18..539252af5cf9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3795,6 +3795,95 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
+namespace {
+/// Pattern to rewrite a subtensor_insert op with constant arguments.
+class SubTensorInsertOpConstantArgumentFolder final
+ : public OpRewritePattern<SubTensorInsertOp> {
+public:
+ using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
+ PatternRewriter &rewriter) const override {
+ // No constant operand, just return.
+ if (llvm::none_of(subTensorInsertOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return failure();
+
+ // At least one of offsets/sizes/strides is a new constant.
+ // Form the new list of operands and constant attributes from the existing.
+ SmallVector<OpFoldResult> mixedOffsets(subTensorInsertOp.getMixedOffsets());
+ SmallVector<OpFoldResult> mixedSizes(subTensorInsertOp.getMixedSizes());
+ SmallVector<OpFoldResult> mixedStrides(subTensorInsertOp.getMixedStrides());
+ canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+ canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+ // Create the new op in canonical form.
+ Value source = subTensorInsertOp.source();
+ RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
+ SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
+ llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+ return attr.cast<IntegerAttr>().getInt();
+ return ShapedType::kDynamicSize;
+ }));
+ RankedTensorType newSourceType =
+ RankedTensorType::get(shape, sourceType.getElementType());
+ Location loc = subTensorInsertOp.getLoc();
+ if (sourceType != newSourceType)
+ source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
+ rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+ subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
+ mixedSizes, mixedStrides);
+ return success();
+ }
+};
+
+/// Fold tensor_casts with subtensor_insert operations.
+struct SubTensorInsertOpCastFolder final
+ : public OpRewritePattern<SubTensorInsertOp> {
+ using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
+ PatternRewriter &rewriter) const override {
+ if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return failure();
+
+ auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
+ auto castOp = v.getDefiningOp<tensor::CastOp>();
+ if (!castOp || !canFoldIntoConsumerOp(castOp))
+ return llvm::None;
+ return castOp.source();
+ };
+ Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
+ Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
+ if (!sourceCastSource && !destCastSource &&
+ subTensorOp.dest().getType() == subTensorOp.getResult().getType())
+ return failure();
+
+ auto newOp = rewriter.create<SubTensorInsertOp>(
+ subTensorOp.getLoc(),
+ (sourceCastSource ? *sourceCastSource : subTensorOp.source()),
+ (destCastSource ? *destCastSource : subTensorOp.dest()),
+ subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+ subTensorOp.getMixedStrides());
+
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
+ subTensorOp.getType(), newOp);
+ return success();
+ }
+};
+} // namespace
+
+void SubTensorInsertOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SubTensorInsertOpConstantArgumentFolder,
+ SubTensorInsertOpCastFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
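For illustration, the cast-folding pattern rewrites along these lines (an
illustrative sketch, not taken from the commit; names and shapes are made up):

  // Before: the destination is produced by a tensor.cast that only erases
  // static dimensions.
  %d = tensor.cast %dst : tensor<3x9xi32> to tensor<?x?xi32>
  %r = subtensor_insert %src into %d[0, 1] [2, 3] [1, 1]
      : tensor<2x3xi32> into tensor<?x?xi32>

  // After canonicalization: the insert uses the pre-cast destination and the
  // result is cast back to the original, more dynamic type.
  %0 = subtensor_insert %src into %dst[0, 1] [2, 3] [1, 1]
      : tensor<2x3xi32> into tensor<3x9xi32>
  %r = tensor.cast %0 : tensor<3x9xi32> to tensor<?x?xi32>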
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index b887e90e931b..72b886a238ff 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -252,3 +252,51 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<
%res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
return %res : tensor<4x6x16x32xi8>
}
+
+// -----
+
+func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+ %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c8 = constant 8 : index
+ %0 = dim %arg0, %c1 : tensor<2x?xi32>
+ %1 = tensor.extract %arg1[] : tensor<i32>
+ %2 = tensor.generate %arg2, %c8 {
+ ^bb0(%arg4: index, %arg5: index):
+ tensor.yield %1 : i32
+ } : tensor<?x?xi32>
+ %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+ return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+// CHECK: return %[[CAST]]
+
+// -----
+
+func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c9 = constant 9 : index
+ %c3 = constant 3 : index
+ %2 = tensor.extract %arg1[] : tensor<i32>
+ %4 = tensor.generate %c3, %c9 {
+ ^bb0(%arg2: index, %arg3: index):
+ tensor.yield %2 : i32
+ } : tensor<?x?xi32>
+ %5 = subtensor_insert %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
+ %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
+ return %6 : tensor<3x9xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_output_dest_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<2x3xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
+// CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
+// CHECK: %[[GENERATE:.+]] = tensor.generate
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
+// CHECK: return %[[RESULT]]
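As a usage note, the // ----- separators above indicate the test is run with
split-input-file handling; an invocation along these lines exercises it (a
sketch; the exact RUN line lives at the top of canonicalize.mlir and is not
shown in this diff):

  mlir-opt -canonicalize -split-input-file mlir/test/Dialect/Standard/canonicalize.mlir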