[Mlir-commits] [mlir] 7e65fc9 - [mlir][Vector] Support 0-D vectors in `BroadcastOp`
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Nov 26 09:19:54 PST 2021
Author: Michal Terepeta
Date: 2021-11-26T17:17:18Z
New Revision: 7e65fc9a6009ba3297cbca7dc2bffdb0346d158e
URL: https://github.com/llvm/llvm-project/commit/7e65fc9a6009ba3297cbca7dc2bffdb0346d158e
DIFF: https://github.com/llvm/llvm-project/commit/7e65fc9a6009ba3297cbca7dc2bffdb0346d158e.diff
LOG: [mlir][Vector] Support 0-D vectors in `BroadcastOp`
This changes the op to produce `AnyVectorOfAnyRank`, mostly following the existing code for 1-D vectors.
Depends On D114598
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D114550
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f274ff656f253..1bab07e77325c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -302,7 +302,7 @@ def Vector_MultiDimReductionOp :
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
- Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
+ Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
using the given operation (add/mul/min/max for int/fp and and/or/xor for
int only).
@@ -380,7 +380,7 @@ def Vector_BroadcastOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
- Results<(outs AnyVector:$vector)> {
+ Results<(outs AnyVectorOfAnyRank:$vector)> {
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source operand
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5760e80bfcaff..6bdbeb1a550b5 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -546,10 +546,27 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
Type eltType = dstType.getElementType();
+ // Scalar to any vector can use splat.
+ if (!srcType) {
+ rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
+ return success();
+ }
+
// Determine rank of source and destination.
- int64_t srcRank = srcType ? srcType.getRank() : 0;
+ int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
+ // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ if (srcRank <= 1 && dstRank == 1) {
+ Value ext;
+ if (srcRank == 0)
+ ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
+ else
+ ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+ rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
+ return success();
+ }
+
// Duplicate this rank.
// For example:
// %x = broadcast %y : k-D to n-D, k < n
@@ -560,11 +577,6 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// %b = [%y,%y] : (n-1)-D
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
- // Scalar to any vector can use splat.
- if (srcRank == 0) {
- rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
- return success();
- }
// Duplication.
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
@@ -593,14 +605,6 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
return success();
}
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
- if (srcRank == 1) {
- assert(m == 0);
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
- rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
- return success();
- }
-
// Any non-matching dimension forces a stretch along this rank.
// For example:
// %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c700b6bcb5d49..42a264a8aa972 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -35,6 +35,27 @@ func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
// -----
+func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<f32>
+ return %0 : vector<f32>
+}
+// CHECK-LABEL: @broadcast_vec0d_from_f32
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = splat %[[A]] : vector<f32>
+// CHECK: return %[[T0]] : vector<f32>
+
+// -----
+
+func @broadcast_vec0d_from_vec0d(%arg0: vector<f32>) -> vector<f32> {
+ %0 = vector.broadcast %arg0 : vector<f32> to vector<f32>
+ return %0 : vector<f32>
+}
+// CHECK-LABEL: @broadcast_vec0d_from_vec0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: return %[[A]] : vector<f32>
+
+// -----
+
func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
@@ -89,6 +110,26 @@ func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
// -----
+func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<f32> to vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
+// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
+// CHECK: return %[[T10]] : vector<3x2xf32>
+
+// -----
+
func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 4d74366f80da0..8e69d658612b8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -28,6 +28,33 @@ func @splat_0d(%a: f32) {
return
}
+func @broadcast_0d(%a: f32) {
+ %1 = vector.broadcast %a : f32 to vector<f32>
+ // CHECK: ( 42 )
+ vector.print %1: vector<f32>
+
+ %2 = vector.broadcast %1 : vector<f32> to vector<f32>
+ // CHECK: ( 42 )
+ vector.print %2: vector<f32>
+
+ %3 = vector.broadcast %1 : vector<f32> to vector<1xf32>
+ // CHECK: ( 42 )
+ vector.print %3: vector<1xf32>
+
+ %4 = vector.broadcast %1 : vector<f32> to vector<2xf32>
+ // CHECK: ( 42, 42 )
+ vector.print %4: vector<2xf32>
+
+ %5 = vector.broadcast %1 : vector<f32> to vector<2x1xf32>
+ // CHECK: ( ( 42 ), ( 42 ) )
+ vector.print %5: vector<2x1xf32>
+
+ %6 = vector.broadcast %1 : vector<f32> to vector<2x3xf32>
+ // CHECK: ( ( 42, 42, 42 ), ( 42, 42, 42 ) )
+ vector.print %6: vector<2x3xf32>
+ return
+}
+
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -39,6 +66,7 @@ func @entry() {
%4 = arith.constant 42.0 : f32
call @splat_0d(%4) : (f32) -> ()
+ call @broadcast_0d(%4) : (f32) -> ()
return
}
More information about the Mlir-commits mailing list