[llvm-branch-commits] [mlir] c939331 - [mlir][linalg] Fix incorrect reduction detection in Vectorizer

Diego Caballero via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 12 15:27:32 PST 2023


Author: Diego Caballero
Date: 2023-01-12T23:15:29Z
New Revision: c939331e4d97ec667ebf6bd470ff069b42ea6fb4

URL: https://github.com/llvm/llvm-project/commit/c939331e4d97ec667ebf6bd470ff069b42ea6fb4
DIFF: https://github.com/llvm/llvm-project/commit/c939331e4d97ec667ebf6bd470ff069b42ea6fb4.diff

LOG: [mlir][linalg] Fix incorrect reduction detection in Vectorizer

When detecting reductions, make sure the block argument belongs to the body block of the linalg op being vectorized, not to an enclosing region (e.g. an scf.for).
This fixes https://github.com/iree-org/iree/issues/11779.

Co-authored-by: Andrzej Warzynski <andrzej.warzynski at arm.com>

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141413

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4f3d55ffde2bc..1e8335012504d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -752,13 +752,14 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
   for (Value operand : op->getOperands()) {
-    auto arg = operand.dyn_cast<BlockArgument>();
-    if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs())
+    auto blockArg = operand.dyn_cast<BlockArgument>();
+    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
+        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
       continue;
     SmallVector<Operation *> reductionOps;
     Value reduceValue = matchReduction(
         linalgOp.getRegionOutputArgs(),
-        arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
+        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
     if (!reduceValue)
       continue;
     reductionOperands.push_back(std::make_pair(reduceValue, operand));

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2c7d34066a4bc..0ccd6c4b96733 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1853,3 +1853,43 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["func.func"]} in %arg0
   %1 = transform.structured.vectorize %0
 }
+
+// -----
+
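+// Regression test: %13 was incorrectly detected as a reduction because its
+// operand %arg0 is a block argument of the enclosing scf.for (not of the
+// linalg.generic body), and vectorization failed.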
+
+func.func @wrong_reduction_detection(%input: tensor<120x64xf32>) -> tensor<120x64xf32> {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c64 = arith.constant 64 : index
+  %cst_6 = arith.constant 4.000000e+00 : f32
+  %1 = scf.for %arg0 = %c0 to %c64 step %c4 iter_args(%arg1 = %input) -> (tensor<120x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg1[%c0, %arg0] [1, 4] [1, 1] : tensor<120x64xf32> to tensor<1x4xf32>
+    %10 = linalg.fill {__internal_linalg_transform__ = "1"} ins(%cst_6 : f32) outs(%extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32>
+    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%10 : tensor<1x4xf32>) {
+    ^bb0(%out: f32):
+      %12 = linalg.index 0 : index
+      %13 = arith.addi %arg0, %12 : index
+      %18 = arith.index_cast %13 : index to i32
+      %20 = arith.uitofp %18 : i32 to f32
+      %67 = arith.mulf %out, %20 : f32
+      linalg.yield %67 : f32
+    } -> tensor<1x4xf32>
+    %inserted_slice = tensor.insert_slice %11 into %arg1[%c0, %arg0] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<120x64xf32>
+    scf.yield %inserted_slice : tensor<120x64xf32>
+  }
+  return %1 : tensor<120x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+
+// CHECK-LABEL: @wrong_reduction_detection
+// CHECK:         vector.broadcast
+// CHECK:         vector.transfer_write
+


        


More information about the llvm-branch-commits mailing list