[Mlir-commits] [mlir] [mlir][vector] Improve `makeArithReduction` expansion (PR #75846)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 18 11:43:53 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

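Propagate fast math flags to the floating-point arith ops created by `makeArithReduction`.
Distinguish `minf`/`maxf` (now expanded to `arith.minnumf`/`arith.maxnumf`) from `minimumf`/`maximumf` (still expanded to `arith.minimumf`/`arith.maximumf`).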

Required for future patterns in https://github.com/llvm/llvm-project/pull/75727.

---
Full diff: https://github.com/llvm/llvm-project/pull/75846.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4-2) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-7) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12) 
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (+8-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB,
                            bool testDynamicValueUsingBounds = false);
 
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value acc, Value mask = Value());
+                         Value v1, Value acc,
+                         arith::FastMathFlagsAttr fastmath = nullptr,
+                         Value mask = nullptr);
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
                                                 zeroIdx);
     }
 
-    Value result = vector::makeArithReduction(
-        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    Value result =
+        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+                                   cast, /*fastmath=*/nullptr, mask);
     rewriter.replaceOp(rootOp, result);
     return success();
   }
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc, mask);
+                                          result, acc,
+                                          reductionOp.getFastmathAttr(), mask);
 
     rewriter.replaceOp(rootOp, result);
     return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
+                                       arith::FastMathFlagsAttr fastmath,
                                        Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
   Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for ADD reduction");
     break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
     break;
   case CombiningKind::MAXF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MAXIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MINF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MINIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MAXSI:
     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for MUL reduction");
     break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc,
+                            /*fastmath=*/nullptr, mask);
 }
 
 /// Return the positions of the reductions in the given map.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_max_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_min_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
 //  CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
 //  CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/75846


More information about the Mlir-commits mailing list