[Mlir-commits] [mlir] f46744b - [mlir][linalg] Fix FoldTensorCastConsumerOp invalid folding

Ivan Butygin llvmlistbot at llvm.org
Fri Jul 22 02:39:34 PDT 2022


Author: Ivan Butygin
Date: 2022-07-22T11:39:12+02:00
New Revision: f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d

URL: https://github.com/llvm/llvm-project/commit/f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d
DIFF: https://github.com/llvm/llvm-project/commit/f46744bd2a193402a7ea268e4d4a3c9bcbd0f25d.diff

LOG: [mlir][linalg] Fix FoldTensorCastConsumerOp invalid folding

The CastOp can be in a conditionally reachable region, in which case this folding is invalid.
For now, conservatively fold only ops in the same block.

Fixes https://github.com/llvm/llvm-project/issues/56557

Differential Revision: https://reviews.llvm.org/D130314

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c96a1fd01021b..56ce00a0ef803 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1712,10 +1712,17 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
                                 PatternRewriter &rewriter) const override {
     if (!tensor::canFoldIntoProducerOp(castOp))
       return failure();
+
     auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
     if (!linalgOp)
       return failure();
 
+    // The cast can be in a conditionally reachable region, in which case
+    // folding would generate invalid code. For now, conservatively fold only
+    // ops in the same block.
+    if (castOp->getBlock() != linalgOp->getBlock())
+      return failure();
+
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(linalgOp);
 

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b08af21b4bfce..51a7bf6b0ccee 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -846,6 +846,33 @@ func.func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : ten
 //       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
 //       CHECK:  return %[[MATMUL]], %[[RESULT_CAST]]
 
+// -----
+
+func.func private @some_use(%0 : tensor<4x8xf32>)
+
+func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  scf.if %arg3 {
+    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+    func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
+  }
+  return %0 : tensor<?x?xf32>
+}
+
+// Check that a conditionally reachable cast is not folded into its producer.
+// CHECK-LABEL: func @linalgop_with_cond_cast_consumer
+//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
+//       CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+//  CHECK-SAME:      outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK: scf.if %[[ARG3]] {
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
+//       CHECK:   func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
+//       CHECK: }
+//       CHECK: return %[[RES]] : tensor<?x?xf32>
+
+
 // -----
 
 func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,


        


More information about the Mlir-commits mailing list