[Mlir-commits] [mlir] [mlir][vector] Fix crash when folding 0D extract from splat/broadcast (PR #95918)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 18 06:17:19 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
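The `vector.extract` folder (`foldExtractFromBroadcast`) contained an assertion that caused a crash when extracting from a vector defined by a `vector.broadcast` or `vector.splat` op with 0-D semantics. This commit removes the assertion, together with the accompanying early return for 0-D defining ops, and adds test cases to ensure that 0-D scenarios are handled correctly.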
---
Full diff: https://github.com/llvm/llvm-project/pull/95918.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-5)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+38)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e5d3e002086a..2bf4f16f96e6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
- // 0-D vectors not supported.
- assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
- if (hasZeroDimVectors(defOp))
- return Value();
-
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61269e3687ab3..caccd1f1c9c24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
%0 = vector.extract %v[] : f32 from vector<f32>
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat scalar to 0D and extract scalar.
+ %0 = vector.splat %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+ // Broadcast scalar to 0D and extract scalar.
+ %2 = vector.broadcast %a : f32 to vector<f32>
+ %3 = vector.extract %2[] : f32 from vector<f32>
+
+ // Broadcast 0D to 3D and extract scalar.
+ // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
+ %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
+
+ // Splat scalar to 2D and extract scalar.
+ %6 = vector.splat %a : vector<2x3xf32>
+ %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+ // Broadcast scalar to 3D and extract scalar.
+ %8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
+ %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+ // Extract 2D from 3D that was broadcasted from a scalar.
+ // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+ %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+ // Extract 1D from 2D that was splat'ed from a scalar.
+ // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+ %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+ // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+ return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/95918
More information about the Mlir-commits mailing list