[Mlir-commits] [mlir] a5cdcf4 - [mlir] Fix folding for scf.for(tensor.cast).

Alexander Belyaev llvmlistbot at llvm.org
Thu Feb 23 02:23:33 PST 2023


Author: Alexander Belyaev
Date: 2023-02-23T11:23:11+01:00
New Revision: a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b

URL: https://github.com/llvm/llvm-project/commit/a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b
DIFF: https://github.com/llvm/llvm-project/commit/a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b.diff

LOG: [mlir] Fix folding for scf.for(tensor.cast).

We should only fold tensor.cast ops whose source type is at least as static as
their result type (so that folding provides new static shape information to the
loop), instead of looking for the symmetric pattern cast(for(cast)).

Differential Revision: https://reviews.llvm.org/D144577

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 80038b9c436f2..f8fd2016cd9a9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -894,8 +894,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
 ///     scf.yield %2 : tensor<?x?xf32>
 ///   }
-///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
-///   use_of(%2)
+///   use_of(%1)
 /// ```
 ///
 /// folds into:
@@ -908,7 +907,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
 ///     scf.yield %4 : tensor<32x1024xf32>
 ///   }
-///   use_of(%0)
+///   %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///   use_of(%1)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   using OpRewritePattern<ForOp>::OpRewritePattern;
@@ -920,17 +920,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast)
         continue;
+      // If the dest type of the cast does not preserve static information in
+      // the source type.
+      if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(),
+                                         incomingCast.getSource().getType()))
+          continue;
       if (!std::get<1>(it).hasOneUse())
         continue;
-      auto outgoingCastOp =
-          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
-      if (!outgoingCastOp)
-        continue;
-
-      // Must be a tensor.cast op pair with matching types.
-      if (outgoingCastOp.getResult().getType() !=
-          incomingCast.getSource().getType())
-        continue;
 
       // Create a new ForOp with that iter operand replaced.
       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c211596db7445..d3dfd16ba0442 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -850,13 +850,20 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
 
 func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
 
-// CHECK-LABEL: matmul_on_tensors
-//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
-//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
-func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
   %c1024 = arith.constant 1024 : index
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  } {some_attr}
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: matmul_on_tensors
+//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+
 //   CHECK-NOT: tensor.cast
 //       CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
 //       CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
@@ -864,18 +871,8 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32
 //       CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
 //       CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
 //       CHECK: } {some_attr}
-  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
-  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
-    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
-    scf.yield %2 : tensor<?x?xf32>
-  } {some_attr}
-//   CHECK-NOT: tensor.cast
-//       CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
-//       CHECK: return %[[RES]] : tensor<1024x1024xf32>
-  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
-  %res = tensor.insert_slice %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
-  return %res : tensor<1024x1024xf32>
-}
+//       CHECK: %[[RES:.*]] = tensor.cast
+//       CHECK: return %[[RES]] : tensor<?x?xf32>
 
 // -----
 


        


More information about the Mlir-commits mailing list