[Mlir-commits] [mlir] a583616 - [mlir][vector] Fix error handling in VectorizationState::initState

Matthias Springer llvmlistbot at llvm.org
Mon Dec 19 01:24:12 PST 2022


Author: Matthias Springer
Date: 2022-12-19T10:20:05+01:00
New Revision: a583616918ec2984092c99363f200dd31642859a

URL: https://github.com/llvm/llvm-project/commit/a583616918ec2984092c99363f200dd31642859a
DIFF: https://github.com/llvm/llvm-project/commit/a583616918ec2984092c99363f200dd31642859a.diff

LOG: [mlir][vector] Fix error handling in VectorizationState::initState

This function used to create new ops even when vectorization subsequently failed. Those ops were then folded away, which caused the GreedyPatternRewriter to fail: each modification of the IR triggers one more iteration, so the driver never reached a fixed point.

Differential Revision: https://reviews.llvm.org/D140286
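
The fix below moves the dynamic-shape check ahead of the code that materializes helper ops for the dynamic iteration-space sizes, so a rejected vectorization no longer touches the IR. A minimal sketch of that check-before-create ordering, with illustrative function and helper names rather than the actual VectorizationState API, could look like this:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: run every check that can reject vectorization before the
// first rewriter.create call, so returning failure() leaves the IR untouched
// and the greedy driver can reach a fixed point.
static LogicalResult
initVectorizationSketch(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                        ArrayRef<int64_t> canonicalVecShape) {
  // Bail out while the IR is still unmodified; no spurious "progress" is
  // reported to the pattern driver.
  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Only now is it safe to materialize helper ops for the dynamic dimensions
  // of the inputs. If the check above ran after this loop, a rejected op
  // would still leave tensor.dim ops behind; they are folded away, the driver
  // counts that as one more IR modification, and the fixed-point iteration
  // never terminates.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  SmallVector<Value> dynamicSizes;
  for (OpOperand *input : linalgOp.getDpsInputOperands()) {
    auto tensorType = input->get().getType().dyn_cast<RankedTensorType>();
    if (!tensorType)
      continue;
    for (int64_t dim = 0, rank = tensorType.getRank(); dim < rank; ++dim)
      if (tensorType.isDynamicDim(dim))
        dynamicSizes.push_back(rewriter.create<tensor::DimOp>(
            linalgOp.getLoc(), input->get(), dim));
  }
  return success();
}

Keeping every failure path above the first op creation is what makes this safe to run under applyPatternsAndFoldGreedily, which treats any IR change as progress toward a new fixed point.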

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a407b9054fb4..2d6628f04f05f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -179,6 +179,9 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
+  if (ShapedType::isDynamicShape(canonicalVecShape))
+    return failure();
+
   // Initialize iteration space static sizes.
   initIterSpaceStaticSizes(linalgOp);
 
@@ -187,8 +190,6 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
   if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
     return failure();
 
-  if (ShapedType::isDynamicShape(canonicalVecShape))
-    return failure();
   return success();
 }
 

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 96c81d1593a34..74ef59e3ece05 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1750,3 +1750,34 @@ transform.sequence failures(propagate) {
   transform.structured.masked_vectorize %0 vector_sizes [4, 8]
 }
 
+// -----
+
+// This is a regression test. This IR cannot be vectorized, but
+// structured.vectorize should nevertheless succeed.
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL:   @not_vectorizable
+func.func @not_vectorizable(%arg0: tensor<1x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<1x128xf32> {
+  %0 = tensor.empty() : tensor<1x128xf32>
+  %1 = scf.for %arg5 = %arg2 to %arg1 step %arg3 iter_args(%arg6 = %0) -> (tensor<1x128xf32>) {
+    %extracted_slice = tensor.extract_slice %arg6[0, 0] [1, %arg1] [1, 1] : tensor<1x128xf32> to tensor<?xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+    %extracted_slice_1 = tensor.extract_slice %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_0 : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.addf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<?xf32>
+    %inserted_slice = tensor.insert_slice %2 into %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    %collapsed = tensor.collapse_shape %inserted_slice [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+    %inserted_slice_2 = tensor.insert_slice %collapsed into %arg6[0, 0] [1, %arg1] [1, 1] : tensor<?xf32> into tensor<1x128xf32>
+    scf.yield %inserted_slice_2 : tensor<1x128xf32>
+  }
+  return %1 : tensor<1x128xf32>
+}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0
+  %1 = transform.structured.vectorize %0
+}
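
For reference, the checks in vectorization.mlir are driven by the transform dialect interpreter, so the new regression test can be exercised through the normal lit flow (e.g. ninja check-mlir) or, assuming the file keeps its usual RUN line, directly with something along the lines of:

mlir-opt mlir/test/Dialect/Linalg/vectorization.mlir \
  -test-transform-dialect-interpreter -split-input-file \
  | FileCheck mlir/test/Dialect/Linalg/vectorization.mlir

The interesting property is simply that the rewrite no longer loops: before this change the leftover ops from the failed initState kept the greedy rewriter busy, while with the reordered check the transform applies cleanly even though no op is vectorized.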