[Mlir-commits] [mlir] da37c76 - [mlir][vector] Add a check to ensure input vector rank equals target shape rank (#127706)
llvmlistbot at llvm.org
Tue Feb 25 17:39:28 PST 2025
Author: Prakhar Dixit
Date: 2025-02-26T10:39:24+09:00
New Revision: da37c76ac621c64216e56ead3efe1bd569250ee2
URL: https://github.com/llvm/llvm-project/commit/da37c76ac621c64216e56ead3efe1bd569250ee2
DIFF: https://github.com/llvm/llvm-project/commit/da37c76ac621c64216e56ead3efe1bd569250ee2.diff
LOG: [mlir][vector] Add a check to ensure input vector rank equals target shape rank (#127706)
Fixes issue #126197
The crash occurs because, during IR transformation, the
vector-unrolling pass (which uses ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target shape of lower rank, which
is not supported.
Specific example:
```
module {
  func.func @func1() {
    %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
    %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
    %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
    %818 = scf.execute_region -> vector<24x2x2xf32> {
      scf.yield %47 : vector<24x2x2xf32>
    }
    %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
    return
  }
}
```
---------
Co-authored-by: Kai Sasaki <lewuathe at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..08ba972b12ce6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,12 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
+ if (originalSize.size() != targetShape->size())
+ return rewriter.notifyMatchFailure(
+ op, "expected input vector rank to match target shape rank");
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 7e3fe56f6b124..16d30aec7c041 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+ %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+ return %0 : vector<3x2x2xf32>
+}
+// CHECK-LABEL: func @negative_vector_fma_3d
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
+// CHECK: return
+
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>