[Mlir-commits] [mlir] [mlir][vector] Fix vector.broadcast lowering for scalable vectors (PR #66344)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Sep 14 06:10:05 PDT 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/66344:
>From 903b95985875f168aa524e94ec27da62eba407a6 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Sep 2023 08:52:34 +0000
Subject: [PATCH 1/2] [mlir][vector] Fix vector.broadcast lowering for scalable
vectors
This patch makes sure that the following case is lowered correctly
("duplication"):
```
func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
%res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
return %res : vector<1x[32]xf32>
}
```
---
.../Vector/Transforms/LowerVectorBroadcast.cpp | 3 ++-
.../Vector/vector-broadcast-lowering-transforms.mlir | 11 +++++++++++
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 7c606e0c35f0899..2937b2d08b06979 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -85,7 +85,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank < dstRank) {
// Duplication.
VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
+ VectorType::get(dstType.getShape().drop_front(), eltType,
+ dstType.getScalableDims().drop_front());
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 2d3c88d751192aa..386102cf5b4d225 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
return %0 : vector<4x3x2xf32>
}
+// CHECK-LABEL: func.func @broadcast_scalable_duplication
+// CHECK-SAME: %[[ARG0:.*]]: vector<[32]xf32>)
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32>
+// CHECK: return %[[RES]] : vector<1x[32]xf32>
+
+func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
+ %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
+ return %res : vector<1x[32]xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op
>From c9d7d4ac1a4a6c1847ddc5ece6b427ec9b729720 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Sep 2023 13:09:32 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector] Fix vector.broadcast lowering for
scalable vectors
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 2937b2d08b06979..44e3f76112a7ca8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -84,9 +84,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
// Duplication.
- VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType,
- dstType.getScalableDims().drop_front());
+ VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(
More information about the Mlir-commits
mailing list