[Mlir-commits] [mlir] 85882e7 - [mlir][Vector] Support 0-D vectors in ReductionOp

Güray Özen llvmlistbot at llvm.org
Thu Aug 18 02:13:04 PDT 2022


Author: Güray Özen
Date: 2022-08-18T09:12:47Z
New Revision: 85882e7d64bdec697706ddf7ec41fd94add547f7

URL: https://github.com/llvm/llvm-project/commit/85882e7d64bdec697706ddf7ec41fd94add547f7
DIFF: https://github.com/llvm/llvm-project/commit/85882e7d64bdec697706ddf7ec41fd94add547f7.diff

LOG: [mlir][Vector] Support 0-D vectors in ReductionOp

This commit adds support for 0-D vectors in ReductionOp.

Reviewed By: nicolasvasilache, dcaballe

Differential Revision: https://reviews.llvm.org/D131896

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 34460f9c55706..3cc2287599a78 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -268,8 +268,9 @@ def Vector_ReductionOp :
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
-    Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
-                   Optional<AnyType>:$acc)>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+               AnyVectorOfAnyRank:$vector,
+               Optional<AnyType>:$acc)>,
     Results<(outs AnyType:$dest)> {
   let summary = "reduction operation";
   let description = [{

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index dc555e8a92248..b411471983115 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -393,9 +393,9 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult ReductionOp::verify() {
-  // Verify for 1-D vector.
+  // Verify for 0-D and 1-D vectors.
   int64_t rank = getVectorType().getRank();
-  if (rank != 1)
+  if (rank > 1)
     return emitOpError("unsupported reduction rank: ") << rank;
 
   // Verify supported reduction kind.

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1bca6b35685d4..4aa20595c7daf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1117,6 +1117,20 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
 
 // -----
 
+func.func @reduce_0d_f32(%arg0: vector<f32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_0d_f32(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+//      CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+//      CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]])
+// CHECK-SAME: {reassoc = false} : (f32, vector<1xf32>) -> f32
+//      CHECK: return %[[V]] : f32
+
+// -----
+
 func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
   %0 = vector.reduction <add>, %arg0 : vector<16xf16> into f16
   return %0 : f16

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6ccf5295eccd6..a017a8c0cfd45 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -528,6 +528,30 @@ func.func @reduce_int(%arg0: vector<16xi32>) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @reduce_int_0d
+func.func @reduce_int_0d(%arg0: vector<i32>) -> i32 {
+  // CHECK:    vector.reduction <add>, %{{.*}} : vector<i32> into i32
+  vector.reduction <add>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <mul>, %{{.*}} : vector<i32> into i32
+  vector.reduction <mul>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <minui>, %{{.*}} : vector<i32> into i32
+  vector.reduction <minui>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <minsi>, %{{.*}} : vector<i32> into i32
+  vector.reduction <minsi>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <maxui>, %{{.*}} : vector<i32> into i32
+  vector.reduction <maxui>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <maxsi>, %{{.*}} : vector<i32> into i32
+  vector.reduction <maxsi>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <and>, %{{.*}} : vector<i32> into i32
+  vector.reduction <and>, %arg0 : vector<i32> into i32
+  // CHECK:    vector.reduction <or>, %{{.*}} : vector<i32> into i32
+  vector.reduction <or>, %arg0 : vector<i32> into i32
+  // CHECK:    %[[X:.*]] = vector.reduction <xor>, %{{.*}} : vector<i32> into i32
+  %0 = vector.reduction <xor>, %arg0 : vector<i32> into i32
+  // CHECK:    return %[[X]] : i32
+  return %0 : i32
+}
+
 // CHECK-LABEL: @transpose_fp
 func.func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
   // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 985ecdbfd9149..95e52192fa8c9 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,6 +105,11 @@ func.func @create_mask_0d(%zero : index, %one : index) {
   return
 }
 
+func.func @reduce_add(%arg0: vector<f32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32    
+  return %0 : f32
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -131,5 +136,10 @@ func.func @entry() {
   %one_idx = arith.constant 1 : index
   call  @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
 
+  %red_array = arith.constant dense<5.0> : vector<f32>
+  %red_res = call  @reduce_add(%red_array) : (vector<f32>) -> (f32)  
+  vector.print %red_res : f32
+  // CHECK: 5
+
   return
 }


        


More information about the Mlir-commits mailing list