[Mlir-commits] [mlir] 7e65fc9 - [mlir][Vector] Support 0-D vectors in `BroadcastOp`

Nicolas Vasilache llvmlistbot at llvm.org
Fri Nov 26 09:19:54 PST 2021


Author: Michal Terepeta
Date: 2021-11-26T17:17:18Z
New Revision: 7e65fc9a6009ba3297cbca7dc2bffdb0346d158e

URL: https://github.com/llvm/llvm-project/commit/7e65fc9a6009ba3297cbca7dc2bffdb0346d158e
DIFF: https://github.com/llvm/llvm-project/commit/7e65fc9a6009ba3297cbca7dc2bffdb0346d158e.diff

LOG: [mlir][Vector] Support 0-D vectors in `BroadcastOp`

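This changes the result type constraint of `vector.broadcast` from `AnyVector` to `AnyVectorOfAnyRank`, so that the op can produce 0-D vectors. The lowering of 0-D vectors mostly follows the existing code for 1-D vectors.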

Depends On D114598

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114550

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f274ff656f253..1bab07e77325c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -302,7 +302,7 @@ def Vector_MultiDimReductionOp :
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
   let description = [{
-    Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) 
+    Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
     using the given operation (add/mul/min/max for int/fp and and/or/xor for
     int only).
 
@@ -380,7 +380,7 @@ def Vector_BroadcastOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
     Arguments<(ins AnyType:$source)>,
-    Results<(outs AnyVector:$vector)> {
+    Results<(outs AnyVectorOfAnyRank:$vector)> {
   let summary = "broadcast operation";
   let description = [{
     Broadcasts the scalar or k-D vector value in the source operand

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5760e80bfcaff..6bdbeb1a550b5 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -546,10 +546,27 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
     Type eltType = dstType.getElementType();
 
+    // Scalar to any vector can use splat.
+    if (!srcType) {
+      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
+      return success();
+    }
+
     // Determine rank of source and destination.
-    int64_t srcRank = srcType ? srcType.getRank() : 0;
+    int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
+    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+    if (srcRank <= 1 && dstRank == 1) {
+      Value ext;
+      if (srcRank == 0)
+        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
+      else
+        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
+      return success();
+    }
+
     // Duplicate this rank.
     // For example:
     //   %x = broadcast %y  : k-D to n-D, k < n
@@ -560,11 +577,6 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   %b = [%y,%y]       : (n-1)-D
     //   %x = [%b,%b,%b,%b] : n-D
     if (srcRank < dstRank) {
-      // Scalar to any vector can use splat.
-      if (srcRank == 0) {
-        rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
-        return success();
-      }
       // Duplication.
       VectorType resType =
           VectorType::get(dstType.getShape().drop_front(), eltType);
@@ -593,14 +605,6 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
       return success();
     }
 
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
-    if (srcRank == 1) {
-      assert(m == 0);
-      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
-      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
-      return success();
-    }
-
     // Any non-matching dimension forces a stretch along this rank.
     // For example:
     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c700b6bcb5d49..42a264a8aa972 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -35,6 +35,27 @@ func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
 
 // -----
 
+func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<f32>
+  return %0 : vector<f32>
+}
+// CHECK-LABEL: @broadcast_vec0d_from_f32
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<f32>
+// CHECK:       return %[[T0]] : vector<f32>
+
+// -----
+
+func @broadcast_vec0d_from_vec0d(%arg0: vector<f32>) -> vector<f32> {
+  %0 = vector.broadcast %arg0 : vector<f32> to vector<f32>
+  return %0 : vector<f32>
+}
+// CHECK-LABEL: @broadcast_vec0d_from_vec0d(
+// CHECK-SAME:  %[[A:.*]]: vector<f32>)
+// CHECK:       return %[[A]] : vector<f32>
+
+// -----
+
 func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
   return %0 : vector<2xf32>
@@ -89,6 +110,26 @@ func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
 
 // -----
 
+func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<f32> to vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
+// CHECK-SAME:  %[[A:.*]]: vector<f32>)
+//       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+//       CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+//       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
+//       CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
+//       CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
+//       CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
+//       CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
+//       CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
+//       CHECK: return %[[T10]] : vector<3x2xf32>
+
+// -----
+
 func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
   %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
   return %0 : vector<3x2xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 4d74366f80da0..8e69d658612b8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -28,6 +28,33 @@ func @splat_0d(%a: f32) {
   return
 }
 
+func @broadcast_0d(%a: f32) {
+  %1 = vector.broadcast %a : f32 to vector<f32>
+  // CHECK: ( 42 )
+  vector.print %1: vector<f32>
+
+  %2 = vector.broadcast %1 : vector<f32> to vector<f32>
+  // CHECK: ( 42 )
+  vector.print %2: vector<f32>
+
+  %3 = vector.broadcast %1 : vector<f32> to vector<1xf32>
+  // CHECK: ( 42 )
+  vector.print %3: vector<1xf32>
+
+  %4 = vector.broadcast %1 : vector<f32> to vector<2xf32>
+  // CHECK: ( 42, 42 )
+  vector.print %4: vector<2xf32>
+
+  %5 = vector.broadcast %1 : vector<f32> to vector<2x1xf32>
+  // CHECK: ( ( 42 ), ( 42 ) )
+  vector.print %5: vector<2x1xf32>
+
+  %6 = vector.broadcast %1 : vector<f32> to vector<2x3xf32>
+  // CHECK: ( ( 42, 42, 42 ), ( 42, 42, 42 ) )
+  vector.print %6: vector<2x3xf32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -39,6 +66,7 @@ func @entry() {
 
   %4 = arith.constant 42.0 : f32
   call  @splat_0d(%4) : (f32) -> ()
+  call  @broadcast_0d(%4) : (f32) -> ()
 
   return
 }


        


More information about the Mlir-commits mailing list