[Mlir-commits] [mlir] ebf3537 - [mlir][tensor] Insert explicit tensor.cast ops for insert_slice src

Matthias Springer llvmlistbot at llvm.org
Tue Aug 24 03:51:51 PDT 2021


Author: Matthias Springer
Date: 2021-08-24T19:45:04+09:00
New Revision: ebf35370ff596dcbd8a4a74b865cd066440510a2

URL: https://github.com/llvm/llvm-project/commit/ebf35370ff596dcbd8a4a74b865cd066440510a2
DIFF: https://github.com/llvm/llvm-project/commit/ebf35370ff596dcbd8a4a74b865cd066440510a2.diff

LOG: [mlir][tensor] Insert explicit tensor.cast ops for insert_slice src

If additional static type information can be deduced from an insert_slice's size operands, insert an explicit cast of the op's source operand.

This enables other canonicalization patterns that match tensor.cast ops, such as `ForOpTensorCastFolder` in SCF.

Differential Revision: https://reviews.llvm.org/D108617

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5dd3127a508bf..14ce6c104d44f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1085,7 +1085,24 @@ class InsertSliceOpConstantArgumentFolder final
   }
 };
 
-/// Fold tensor_casts with insert_slice operations.
+/// Fold tensor_casts with insert_slice operations. If the source or destination
+/// tensor is a tensor_cast that removes static type information, the cast is
+/// folded into the insert_slice operation. E.g.:
+///
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
+/// ```
+///
+/// Note: When folding a cast on the destination tensor, the result of the
+/// insert_slice operation is casted to ensure that the type of the result did
+/// not change.
 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
 
@@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
     return success();
   }
 };
+
+/// If additional static type information can be deduced from an insert_slice's
+/// size operands, insert an explicit cast of the op's source operand. This
+/// enables other canonicalization patterns that match tensor.cast ops, such
+/// as `ForOpTensorCastFolder` in SCF.
+///
+/// Only dimensions that are dynamic in the source type are refined; static
+/// source dimensions are never overwritten, so the inserted cast is always
+/// cast-compatible with the original source type. The source encoding is
+/// preserved.
+///
+/// Example:
+///
+/// ```mlir
+///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
+///       : tensor<?x?xf32> into ...
+/// ```
+///
+/// is rewritten into:
+///
+/// ```mlir
+///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
+///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
+///       : tensor<64x64xf32> into ...
+/// ```
+struct InsertSliceOpSourceCastInserter final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType srcType = insertSliceOp.getSourceType();
+    // Rank-reducing insert_slice ops are not handled: the size operands do not
+    // map one-to-one onto the source dimensions.
+    if (srcType.getRank() != insertSliceOp.getType().getRank())
+      return failure();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
+                                     srcType.getShape().end());
+    for (int64_t i = 0; i < srcType.getRank(); ++i) {
+      // Never overwrite a static source dimension: a mismatching constant
+      // size would otherwise produce an incompatible (invalid) tensor.cast.
+      if (!srcType.isDynamicDim(i))
+        continue;
+      Optional<int64_t> constInt = getConstantIntValue(mixedSizes[i]);
+      if (!constInt)
+        continue;
+      // Bail on invalid IR: a negative size cannot form a valid tensor type.
+      if (*constInt < 0)
+        return failure();
+      newSrcShape[i] = *constInt;
+    }
+    // Keep the source encoding so that only the shape is refined.
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
+    if (srcType == newSrcType)
+      return failure();
+
+    // newSrcType is strictly more static than srcType. Insert a cast.
+    Value cast = rewriter.create<tensor::CastOp>(
+        insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, cast, insertSliceOp.dest(),
+        insertSliceOp.getMixedOffsets(), mixedSizes,
+        insertSliceOp.getMixedStrides());
+    return success();
+  }
+};
 } // namespace
 
+/// Registers InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder and
+/// InsertSliceOpSourceCastInserter as canonicalization patterns of
+/// tensor.insert_slice.
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
-      context);
+  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
+              InsertSliceOpSourceCastInserter>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7a6c368b5052e..6ea213743a298 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -666,7 +666,7 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
   return %res : tensor<1024x1024xf32>
 }
 
-
+// -----
 
 // CHECK-LABEL: @cond_prop
 func @cond_prop(%arg0 : i1) -> index {
@@ -707,6 +707,8 @@ func @cond_prop(%arg0 : i1) -> index {
 // CHECK-NEXT:  return %[[if]] : index
 // CHECK-NEXT:}
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond1
 func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = constant true
@@ -729,6 +731,8 @@ func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
 // CHECK-NEXT:    }
 // CHECK-NEXT:    return %[[if]], %arg0 : i32, i1
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond2
 func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
   %true = constant true
@@ -753,6 +757,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[if]], %[[toret]] : i32, i1
 
+// -----
 
 // CHECK-LABEL: @replace_if_with_cond3
 func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
@@ -774,6 +779,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[if]], %arg1 : i32, i64
 
+// -----
 
 // CHECK-LABEL: @while_cond_true
 func @while_cond_true() {

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f0259952da380..7ef93fbe1b10f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -366,10 +366,11 @@ func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
 }
 // CHECK-LABEL: func @insert_slice_canonicalize
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+//       CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-//  CHECK-SAME:      : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-//       CHEKC:   return %[[RESULT]]
+//  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   return %[[RESULT]]
 
 // -----
 
@@ -517,3 +518,17 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// -----
+
+// The constant sizes (64) refine the dynamic source dims, while the already
+// static source dim (5) is kept; the refined source type is materialized with
+// an explicit tensor.cast feeding the insert_slice.
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
+// CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+//      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
+//      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
+//      CHECK:    return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src(
+  %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+    : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+  return %r : tensor<?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list