[Mlir-commits] [mlir] [mlir][vector] Fix crash when folding 0D extract from splat/broadcast (PR #95918)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 18 06:17:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

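The `foldExtractFromBroadcast` folder asserted that the `vector.extract` op involved no 0-D vectors. As a result, it crashed when extracting from a vector defined by a `vector.broadcast` or `vector.splat` op with 0-D semantics. This commit removes that assertion, along with the early bail-out that skipped folding when the defining op involved 0-D vectors. It also adds test cases to ensure that 0-D scenarios are handled correctly.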

---
Full diff: https://github.com/llvm/llvm-project/pull/95918.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-5) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+38) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e5d3e002086a..2bf4f16f96e6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
 
-  // 0-D vectors not supported.
-  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
-  if (hasZeroDimVectors(defOp))
-    return Value();
-
   Value source = defOp->getOperand(0);
   if (extractOp.getType() == source.getType())
     return source;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61269e3687ab3..caccd1f1c9c24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
   %0 = vector.extract %v[] : f32 from vector<f32>
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+  // Splat scalar to 0D and extract scalar.
+  %0 = vector.splat %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Broadcast scalar to 0D and extract scalar.
+  %2 = vector.broadcast %a : f32 to vector<f32>
+  %3 = vector.extract %2[] : f32 from vector<f32>
+
+  // Broadcast 0D to 3D and extract scalar.
+  // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+  %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
+  %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
+
+  // Splat scalar to 2D and extract scalar.
+  %6 = vector.splat %a : vector<2x3xf32>
+  %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+  // Broadcast scalar to 3D and extract scalar.
+  %8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
+  %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+  // Extract 2D from 3D that was broadcasted from a scalar.
+  // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+  %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+  // Extract 1D from 2D that was splat'ed from a scalar.
+  // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+  %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+  // CHECK:   return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+  return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/95918


More information about the Mlir-commits mailing list