[Mlir-commits] [mlir] 85882e7 - [mlir][Vector] Support 0-D vectors in ReductionOp
Güray Özen
llvmlistbot at llvm.org
Thu Aug 18 02:13:04 PDT 2022
Author: Güray Özen
Date: 2022-08-18T09:12:47Z
New Revision: 85882e7d64bdec697706ddf7ec41fd94add547f7
URL: https://github.com/llvm/llvm-project/commit/85882e7d64bdec697706ddf7ec41fd94add547f7
DIFF: https://github.com/llvm/llvm-project/commit/85882e7d64bdec697706ddf7ec41fd94add547f7.diff
LOG: [mlir][Vector] Support 0-D vectors in ReductionOp
This commit adds support for 0-D vectors in ReductionOp.
Reviewed By: nicolasvasilache, dcaballe
Differential Revision: https://reviews.llvm.org/D131896
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 34460f9c55706..3cc2287599a78 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -268,8 +268,9 @@ def Vector_ReductionOp :
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
- Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
- Optional<AnyType>:$acc)>,
+ Arguments<(ins Vector_CombiningKindAttr:$kind,
+ AnyVectorOfAnyRank:$vector,
+ Optional<AnyType>:$acc)>,
Results<(outs AnyType:$dest)> {
let summary = "reduction operation";
let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index dc555e8a92248..b411471983115 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -393,9 +393,9 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult ReductionOp::verify() {
- // Verify for 1-D vector.
+ // Verify for 0-D and 1-D vector.
int64_t rank = getVectorType().getRank();
- if (rank != 1)
+ if (rank > 1)
return emitOpError("unsupported reduction rank: ") << rank;
// Verify supported reduction kind.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1bca6b35685d4..4aa20595c7daf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1117,6 +1117,20 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
// -----
+func.func @reduce_0d_f32(%arg0: vector<f32>) -> f32 { // rank-0 (0-D) vector operand
+ %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_0d_f32(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]])
+// CHECK-SAME: {reassoc = false} : (f32, vector<1xf32>) -> f32
+// CHECK: return %[[V]] : f32
+
+// -----
+
func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
%0 = vector.reduction <add>, %arg0 : vector<16xf16> into f16
return %0 : f16
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6ccf5295eccd6..a017a8c0cfd45 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -528,6 +528,30 @@ func.func @reduce_int(%arg0: vector<16xi32>) -> i32 {
return %0 : i32
}
+// CHECK-LABEL: @reduce_int_0d
+func.func @reduce_int_0d(%arg0: vector<i32>) -> i32 { // every integer kind on a 0-D vector
+ // CHECK: vector.reduction <add>, %{{.*}} : vector<i32> into i32
+ vector.reduction <add>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <mul>, %{{.*}} : vector<i32> into i32
+ vector.reduction <mul>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <minui>, %{{.*}} : vector<i32> into i32
+ vector.reduction <minui>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <minsi>, %{{.*}} : vector<i32> into i32
+ vector.reduction <minsi>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <maxui>, %{{.*}} : vector<i32> into i32
+ vector.reduction <maxui>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <maxsi>, %{{.*}} : vector<i32> into i32
+ vector.reduction <maxsi>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <and>, %{{.*}} : vector<i32> into i32
+ vector.reduction <and>, %arg0 : vector<i32> into i32
+ // CHECK: vector.reduction <or>, %{{.*}} : vector<i32> into i32
+ vector.reduction <or>, %arg0 : vector<i32> into i32
+ // CHECK: %[[X:.*]] = vector.reduction <xor>, %{{.*}} : vector<i32> into i32
+ %0 = vector.reduction <xor>, %arg0 : vector<i32> into i32
+ // CHECK: return %[[X]] : i32
+ return %0 : i32
+}
+
// CHECK-LABEL: @transpose_fp
func.func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
// CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 985ecdbfd9149..95e52192fa8c9 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,6 +105,11 @@ func.func @create_mask_0d(%zero : index, %one : index) {
return
}
+func.func @reduce_add(%arg0: vector<f32>) -> f32 { // add-reduces a rank-0 (0-D) vector
+ %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+ return %0 : f32
+}
+
func.func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -131,5 +136,10 @@ func.func @entry() {
%one_idx = arith.constant 1 : index
call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
+ %red_array = arith.constant dense<5.0> : vector<f32>
+ %red_res = call @reduce_add(%red_array) : (vector<f32>) -> (f32)
+ vector.print %red_res : f32
+ // CHECK: 5
+
return
}
More information about the Mlir-commits
mailing list