[Mlir-commits] [mlir] [mlir][vector] Fix vector.broadcast lowering for scalable vectors (PR #66344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 02:05:49 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core
            
<details>
<summary>Changes</summary>
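This patch fixes the lowering of `vector.broadcast` in the "duplication" case
(a broadcast that only adds leading dimensions) when the vectors are scalable.
Previously, the intermediate vector type used by the lowering was created
without the scalable-dimension flags, so e.g. `vector<[32]xf32>` became a
fixed-length `vector<32xf32>`. The following case is now lowered correctly: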
```
  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }
```

--
Full diff: https://github.com/llvm/llvm-project/pull/66344.diff

2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1) 
- (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+11) 


<pre>
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 7c606e0c35f0899..2937b2d08b06979 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -85,7 +85,8 @@ class BroadcastOpLowering : public OpRewritePattern&lt;vector::BroadcastOp&gt; {
     if (srcRank &lt; dstRank) {
       // Duplication.
       VectorType resType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
+          VectorType::get(dstType.getShape().drop_front(), eltType,
+                          dstType.getScalableDims().drop_front());
       Value bcst =
           rewriter.create&lt;vector::BroadcastOp&gt;(loc, resType, op.getSource());
       Value result = rewriter.create&lt;arith::ConstantOp&gt;(
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 2d3c88d751192aa..386102cf5b4d225 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector&lt;4x1x2xf32&gt;) -&gt; vector&lt;4x3x2
   return %0 : vector&lt;4x3x2xf32&gt;
 }
 
+// CHECK-LABEL:   func.func @broadcast_scalable_duplication
+// CHECK-SAME:      %[[ARG0:.*]]: vector&lt;[32]xf32&gt;)
+// CHECK:           %[[CST:.*]] = arith.constant dense&lt;0.000000e+00&gt; : vector&lt;1x[32]xf32&gt;
+// CHECK:           %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector&lt;[32]xf32&gt; into vector&lt;1x[32]xf32&gt;
+// CHECK:           return %[[RES]] : vector&lt;1x[32]xf32&gt;
+
+func.func @broadcast_scalable_duplication(%arg0: vector&lt;[32]xf32&gt;) -&gt; vector&lt;1x[32]xf32&gt; {
+  %res = vector.broadcast %arg0 : vector&lt;[32]xf32&gt; to vector&lt;1x[32]xf32&gt;
+  return %res : vector&lt;1x[32]xf32&gt;
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %f = transform.structured.match ops{[&quot;func.func&quot;]} in %module_op 
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66344


More information about the Mlir-commits mailing list