[Mlir-commits] [mlir] [mlir][vector] Improve `makeArithReduction` expansion (PR #75846)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 18 11:43:53 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
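Propagate fast-math flags to the generated arith ops.
Distinguish `minf`/`maxf` (now lowered to `arith.minnumf`/`arith.maxnumf`) from `minimumf`/`maximumf` (still lowered to `arith.minimumf`/`arith.maximumf`).
This is required for future patterns in https://github.com/llvm/llvm-project/pull/75727.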
---
Full diff: https://github.com/llvm/llvm-project/pull/75846.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4-2)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-7)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+2-2)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (+8-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds = false);
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value acc, Value mask = Value());
+ Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath = nullptr,
+ Value mask = nullptr);
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
zeroIdx);
}
- Value result = vector::makeArithReduction(
- rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+ Value result =
+ vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+ cast, /*fastmath=*/nullptr, mask);
rewriter.replaceOp(rootOp, result);
return success();
}
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, acc, mask);
+ result, acc,
+ reductionOp.getFastmathAttr(), mask);
rewriter.replaceOp(rootOp, result);
return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath,
Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for ADD reduction");
break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
break;
case CombiningKind::MAXF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MAXIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MINF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MINIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for MUL reduction");
break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc,
+ /*fastmath=*/nullptr, mask);
}
/// Return the positions of the reductions in the given map.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_max_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_min_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/75846
More information about the Mlir-commits mailing list