[Mlir-commits] [mlir] da37c76 - [mlir][vector] Add a check to ensure input vector rank equals target shape rank (#127706)

llvmlistbot at llvm.org
Tue Feb 25 17:39:28 PST 2025


Author: Prakhar Dixit
Date: 2025-02-26T10:39:24+09:00
New Revision: da37c76ac621c64216e56ead3efe1bd569250ee2

URL: https://github.com/llvm/llvm-project/commit/da37c76ac621c64216e56ead3efe1bd569250ee2
DIFF: https://github.com/llvm/llvm-project/commit/da37c76ac621c64216e56ead3efe1bd569250ee2.diff

LOG: [mlir][vector] Add a check to ensure input vector rank equals target shape rank (#127706)

Fixes issue #126197

The crash occurs because, during IR transformation, the
vector-unrolling pass (via ExtractStridedSliceOp) tries to slice a
higher-rank input vector using a lower-rank target shape, which
is not supported.

Specific example:
```
module {
  func.func @func1() {
    %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
    %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
    %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
    %818 = scf.execute_region -> vector<24x2x2xf32> {
        scf.yield %47 : vector<24x2x2xf32>
      }
    %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
    return
  }
}
```
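
For illustration only, the bail-out this commit adds can be modelled outside MLIR. The sketch below is not part of the patch: `countUnrolledTiles` is a hypothetical helper that uses plain `std::vector` in place of the pattern's `originalSize` and `targetShape`, and it assumes every dimension divides evenly. It returns no value where the real pattern would call `notifyMatchFailure`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model, not MLIR API: returns the number of tiles an
// elementwise op would be unrolled into, or std::nullopt when the pattern
// should bail out because the ranks differ.
std::optional<int64_t>
countUnrolledTiles(const std::vector<int64_t> &originalSize,
                   const std::vector<int64_t> &targetShape) {
  // Mirrors the new check: rank(source) != rank(target) means "no match".
  if (originalSize.size() != targetShape.size())
    return std::nullopt;
  int64_t tiles = 1;
  for (std::size_t i = 0; i < originalSize.size(); ++i) {
    // Assumption for this sketch: each dimension divides evenly.
    assert(targetShape[i] > 0 && originalSize[i] % targetShape[i] == 0);
    tiles *= originalSize[i] / targetShape[i];
  }
  return tiles;
}
```

With a rank-2 target shape such as {2, 2} (assumed here from the `vector<2x2xf32>` results in the existing 2-D fma test), a {4, 4} input yields four tiles, while a {3, 2, 2} input is rejected up front rather than reaching the slicing step.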

---------

Co-authored-by: Kai Sasaki <lewuathe at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..08ba972b12ce6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,12 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          op, "expected input vector rank to match target shape rank");
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 7e3fe56f6b124..16d30aec7c041 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
+// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+  %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+  return %0 : vector<3x2x2xf32>
+}
+// CHECK-LABEL: func @negative_vector_fma_3d
+//   CHECK-NOT: vector.extract_strided_slice
+//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
+//       CHECK: return 
+
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
   return %0 : vector<4xf32>
