[Mlir-commits] [mlir] f46744b - [mlir][linalg] Fix FoldTensorCastConsumerOp invalid folding
Ivan Butygin
llvmlistbot at llvm.org
Fri Jul 22 02:39:34 PDT 2022
Author: Ivan Butygin
Date: 2022-07-22T11:39:12+02:00
New Revision: f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d
URL: https://github.com/llvm/llvm-project/commit/f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d
DIFF: https://github.com/llvm/llvm-project/commit/f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d.diff
LOG: [mlir][linalg] Fix FoldTensorCastConsumerOp invalid folding
The CastOp can be in a conditionally reachable region, in which case this folding is invalid.
For now, conservatively fold only when both ops are in the same block.
Fixes https://github.com/llvm/llvm-project/issues/56557
Differential Revision: https://reviews.llvm.org/D130314
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c96a1fd01021b..56ce00a0ef803 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1712,10 +1712,17 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
PatternRewriter &rewriter) const override {
if (!tensor::canFoldIntoProducerOp(castOp))
return failure();
+
auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
if (!linalgOp)
return failure();
+ // The cast can be in a conditionally reachable region, in which case folding
+ // would generate invalid code. For now, conservatively fold only when both
+ // ops are in the same block.
+ if (castOp->getBlock() != linalgOp->getBlock())
+ return failure();
+
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(linalgOp);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b08af21b4bfce..51a7bf6b0ccee 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -846,6 +846,33 @@ func.func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : ten
// CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
// CHECK: return %[[MATMUL]], %[[RESULT_CAST]]
+// -----
+
+func.func private @some_use(%0 : tensor<4x8xf32>)
+
+func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ scf.if %arg3 {
+ %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+ func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
+ }
+ return %0 : tensor<?x?xf32>
+}
+
+// Check that a conditionally reachable cast is not folded into its producer.
+// CHECK-LABEL: func @linalgop_with_cond_cast_consumer
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
+// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: scf.if %[[ARG3]] {
+// CHECK: %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
+// CHECK: func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
+// CHECK: }
+// CHECK: return %[[RES]] : tensor<?x?xf32>
+
+
// -----
func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,
More information about the Mlir-commits
mailing list