[Mlir-commits] [mlir] a5cdcf4 - [mlir] Fix folding for scf.for(tensor.cast).
Alexander Belyaev
llvmlistbot at llvm.org
Thu Feb 23 02:23:33 PST 2023
Author: Alexander Belyaev
Date: 2023-02-23T11:23:11+01:00
New Revision: a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b
URL: https://github.com/llvm/llvm-project/commit/a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b
DIFF: https://github.com/llvm/llvm-project/commit/a5cdcf49b2f7895dbdf3d00fe8cd7d3fb4f9d38b.diff
LOG: [mlir] Fix folding for scf.for(tensor.cast).
We should only fold tensor.casts that provide some new static information about
shapes, instead of looking for a symmetric pattern cast(for(cast)).
Differential Revision: https://reviews.llvm.org/D144577
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 80038b9c436f2..f8fd2016cd9a9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -894,8 +894,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// scf.yield %2 : tensor<?x?xf32>
/// }
-/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
-/// use_of(%2)
+/// use_of(%1)
/// ```
///
/// folds into:
@@ -908,7 +907,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
/// scf.yield %4 : tensor<32x1024xf32>
/// }
-/// use_of(%0)
+/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
+/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
@@ -920,17 +920,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast)
continue;
+ // Skip casts whose source type does not have at least as much static
+ // information as their dest type (such casts add no new shape info).
+ if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(),
+ incomingCast.getSource().getType()))
+ continue;
if (!std::get<1>(it).hasOneUse())
continue;
- auto outgoingCastOp =
- dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
- if (!outgoingCastOp)
- continue;
-
- // Must be a tensor.cast op pair with matching types.
- if (outgoingCastOp.getResult().getType() !=
- incomingCast.getSource().getType())
- continue;
// Create a new ForOp with that iter operand replaced.
auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c211596db7445..d3dfd16ba0442 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -850,13 +850,20 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-LABEL: matmul_on_tensors
-// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
-// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
-func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
+ %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+ %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+ %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ scf.yield %2 : tensor<?x?xf32>
+ } {some_attr}
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: matmul_on_tensors
+// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+
// CHECK-NOT: tensor.cast
// CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
// CHECK: %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
@@ -864,18 +871,8 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32
// CHECK: %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
// CHECK: scf.yield %[[UNCAST]] : tensor<32x1024xf32>
// CHECK: } {some_attr}
- %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
- %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
- %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
- scf.yield %2 : tensor<?x?xf32>
- } {some_attr}
-// CHECK-NOT: tensor.cast
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
-// CHECK: return %[[RES]] : tensor<1024x1024xf32>
- %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
- %res = tensor.insert_slice %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
- return %res : tensor<1024x1024xf32>
-}
+// CHECK: %[[RES:.*]] = tensor.cast
+// CHECK: return %[[RES]] : tensor<?x?xf32>
// -----
More information about the Mlir-commits mailing list