[Mlir-commits] [mlir] ebf3537 - [mlir][tensor] Insert explicit tensor.cast ops for insert_slice src
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 24 03:51:51 PDT 2021
Author: Matthias Springer
Date: 2021-08-24T19:45:04+09:00
New Revision: ebf35370ff596dcbd8a4a74b865cd066440510a2
URL: https://github.com/llvm/llvm-project/commit/ebf35370ff596dcbd8a4a74b865cd066440510a2
DIFF: https://github.com/llvm/llvm-project/commit/ebf35370ff596dcbd8a4a74b865cd066440510a2.diff
LOG: [mlir][tensor] Insert explicit tensor.cast ops for insert_slice src
If additional static type information can be deduced from an insert_slice's size operands, insert an explicit cast of the op's source operand.
This enables other canonicalization patterns that match tensor.cast ops, such as `ForOpTensorCastFolder` in SCF.
Differential Revision: https://reviews.llvm.org/D108617
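For illustration, here is a minimal sketch of the effect of the new `InsertSliceOpSourceCastInserter` pattern when run under `-canonicalize`. The function name, value names, and shapes below are made up for this example and are not taken from the commit:

```mlir
// The slice sizes [64, 64] are static, but the source value still has a
// fully dynamic type.
func @example(%src : tensor<?x?xf32>, %dst : tensor<128x128xf32>) -> tensor<128x128xf32> {
  %r = tensor.insert_slice %src into %dst[0, 0] [64, 64] [1, 1]
      : tensor<?x?xf32> into tensor<128x128xf32>
  return %r : tensor<128x128xf32>
}
```

The canonicalizer rewrites this to:

```mlir
// An explicit tensor.cast now carries the deduced 64x64 shape, and the
// insert_slice consumes the refined type.
func @example(%src : tensor<?x?xf32>, %dst : tensor<128x128xf32>) -> tensor<128x128xf32> {
  %cast = tensor.cast %src : tensor<?x?xf32> to tensor<64x64xf32>
  %r = tensor.insert_slice %cast into %dst[0, 0] [64, 64] [1, 1]
      : tensor<64x64xf32> into tensor<128x128xf32>
  return %r : tensor<128x128xf32>
}
```

Once the cast is explicit, patterns that match tensor.cast ops can apply; for example, when such a value flows through `scf.for` iter_args, `ForOpTensorCastFolder` can fold the cast around the loop.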
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5dd3127a508bf..14ce6c104d44f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1085,7 +1085,24 @@ class InsertSliceOpConstantArgumentFolder final
}
};
-/// Fold tensor_casts with insert_slice operations.
+/// Fold tensor_casts with insert_slice operations. If the source or destination
+/// tensor is a tensor_cast that removes static type information, the cast is
+/// folded into the insert_slice operation. E.g.:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
+/// ```
+///
+/// Note: When folding a cast on the destination tensor, the result of the
+/// insert_slice operation is cast back so that the type of the result does
+/// not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
@@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
return success();
}
};
+
+/// If additional static type information can be deduced from an insert_slice's
+/// size operands, insert an explicit cast of the op's source operand. This
+/// enables other canonicalization patterns that match tensor.cast ops, such as
+/// `ForOpTensorCastFolder` in SCF.
+///
+/// Example:
+///
+/// ```mlir
+/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
+/// : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
+/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
+/// : tensor<64x64xf32> into ...
+/// ```
+struct InsertSliceOpSourceCastInserter final
+ : public OpRewritePattern<InsertSliceOp> {
+ using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ RankedTensorType srcType = insertSliceOp.getSourceType();
+ if (srcType.getRank() != insertSliceOp.getType().getRank())
+ return failure();
+ SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
+ srcType.getShape().end());
+ for (int64_t i = 0; i < srcType.getRank(); ++i) {
+ if (Optional<int64_t> constInt =
+ getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+ newSrcShape[i] = *constInt;
+ }
+ RankedTensorType newSrcType =
+ RankedTensorType::get(newSrcShape, srcType.getElementType());
+ if (srcType == newSrcType)
+ return failure();
+
+ // srcType and newSrcType are different. Insert a cast.
+ Value cast = rewriter.create<tensor::CastOp>(
+ insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
+ rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ insertSliceOp, cast, insertSliceOp.dest(),
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides());
+ return success();
+ }
+};
} // namespace
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
- context);
+ results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
+ InsertSliceOpSourceCastInserter>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7a6c368b5052e..6ea213743a298 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -666,7 +666,7 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
return %res : tensor<1024x1024xf32>
}
-
+// -----
// CHECK-LABEL: @cond_prop
func @cond_prop(%arg0 : i1) -> index {
@@ -707,6 +707,8 @@ func @cond_prop(%arg0 : i1) -> index {
// CHECK-NEXT: return %[[if]] : index
// CHECK-NEXT:}
+// -----
+
// CHECK-LABEL: @replace_if_with_cond1
func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = constant true
@@ -729,6 +731,8 @@ func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
+// -----
+
// CHECK-LABEL: @replace_if_with_cond2
func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
%true = constant true
@@ -753,6 +757,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
+// -----
// CHECK-LABEL: @replace_if_with_cond3
func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
@@ -774,6 +779,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
+// -----
// CHECK-LABEL: @while_cond_true
func @while_cond_true() {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f0259952da380..7ef93fbe1b10f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -366,10 +366,11 @@ func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
}
// CHECK-LABEL: func @insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-// CHEKC: return %[[RESULT]]
+// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
// -----
@@ -517,3 +518,17 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%2 = tensor.dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}
+
+// -----
+
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
+// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src(
+ %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+ : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+ return %r : tensor<?x?x?xf32>
+}