[Mlir-commits] [mlir] 57cf689 - [mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 08:35:50 PDT 2023


Author: Andrzej Warzyński
Date: 2023-09-15T16:35:47+01:00
New Revision: 57cf6896cd5a48d6978372b9e0fa93fa5381bbba

URL: https://github.com/llvm/llvm-project/commit/57cf6896cd5a48d6978372b9e0fa93fa5381bbba
DIFF: https://github.com/llvm/llvm-project/commit/57cf6896cd5a48d6978372b9e0fa93fa5381bbba.diff

LOG: [mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)

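This patch makes sure that the following case ("duplication") is lowered
correctly. Previously, the intermediate broadcast type was rebuilt with
`VectorType::get(dstType.getShape().drop_front(), eltType)`, which discards the
scalable-dimension flags of `dstType` (so `[32]` became a fixed-size `32`);
`VectorType::Builder(dstType).dropDim(0)` drops the leading dimension while
preserving them: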
```
  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }
```

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
    mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 7c606e0c35f0899..44e3f76112a7ca8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -84,8 +84,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   %x = [%b,%b,%b,%b] : n-D
     if (srcRank < dstRank) {
       // Duplication.
-      VectorType resType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
+      VectorType resType = VectorType::Builder(dstType).dropDim(0);
       Value bcst =
           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
       Value result = rewriter.create<arith::ConstantOp>(

diff  --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 2d3c88d751192aa..386102cf5b4d225 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
   return %0 : vector<4x3x2xf32>
 }
 
+// CHECK-LABEL:   func.func @broadcast_scalable_duplication
+// CHECK-SAME:      %[[ARG0:.*]]: vector<[32]xf32>)
+// CHECK:           %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32>
+// CHECK:           %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32>
+// CHECK:           return %[[RES]] : vector<1x[32]xf32>
+
+func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
+  %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
+  return %res : vector<1x[32]xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %f = transform.structured.match ops{["func.func"]} in %module_op 


        


More information about the Mlir-commits mailing list