[Mlir-commits] [mlir] f9d988f - [mlir][math] Added basic support for FPowI operation.
Slava Zakharin
llvmlistbot at llvm.org
Tue Aug 16 09:24:43 PDT 2022
Author: Slava Zakharin
Date: 2022-08-16T09:24:01-07:00
New Revision: f9d988f1acde945c5cbdabe468c147492108dc81
URL: https://github.com/llvm/llvm-project/commit/f9d988f1acde945c5cbdabe468c147492108dc81
DIFF: https://github.com/llvm/llvm-project/commit/f9d988f1acde945c5cbdabe468c147492108dc81.diff
LOG: [mlir][math] Added basic support for FPowI operation.
The operation computes pow(b, p), where 'b' is floating point
and 'p' is a signed integer. The result's type matches 'b' type.
The operands must have the same shape.
Differential Revision: https://reviews.llvm.org/D129811
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
mlir/test/Dialect/Math/algebraic-simplification.mlir
mlir/test/Dialect/Math/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 98ed2eed39c6..706e2adb5852 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -770,4 +770,51 @@ def Math_RoundOp : Math_FloatUnaryOp<"round"> {
}];
}
+//===----------------------------------------------------------------------===//
+// FPowIOp
+//===----------------------------------------------------------------------===//
+
+def Math_FPowIOp : Math_Op<"fpowi",
+ [SameOperandsAndResultShape, AllTypesMatch<["lhs", "result"]>]> {
+ let summary = "floating point raised to the signed integer power";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `math.fpowi` ssa-use `,` ssa-use `:` type `,` type
+ ```
+
+ The `fpowi` operation takes a `base` operand of floating point type
+ (i.e. scalar, tensor or vector) and a `power` operand of integer type
+ (also scalar, tensor or vector) and returns one result of the same type
+ as `base`. The result is `base` raised to the power of `power`.
+ The operation is elementwise for non-scalars, e.g.:
+
+ ```mlir
+ %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
+ ```
+
+ The result is a vector of:
+
+ ```
+ [<math.fpowi %base[0], %power[0]>, <math.fpowi %base[1], %power[1]>]
+ ```
+
+ Example:
+
+ ```mlir
+ // Scalar exponentiation.
+ %a = math.fpowi %base, %power : f64, i32
+ ```
+ }];
+
+ let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs);
+ let results = (outs FloatLike:$result);
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+
+ // TODO: add a constant folder using pow[f] for cases where
+ // the power argument is exactly representable in the floating
+ // point type of the base.
+}
+
#endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index b967b8699b45..2c2821531e46 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -113,27 +113,31 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
}
//----------------------------------------------------------------------------//
-// IPowIOp strength reduction.
+// FPowIOp/IPowIOp strength reduction.
//----------------------------------------------------------------------------//
namespace {
-struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
+template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
+struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
+
unsigned exponentThreshold;
public:
- IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
- PatternBenefit benefit = 1,
- ArrayRef<StringRef> generatedNames = {})
- : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
+ PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+ PatternBenefit benefit = 1,
+ ArrayRef<StringRef> generatedNames = {})
+ : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
exponentThreshold(exponentThreshold) {}
- LogicalResult matchAndRewrite(math::IPowIOp op,
+
+ LogicalResult matchAndRewrite(PowIOpTy op,
PatternRewriter &rewriter) const final;
};
} // namespace
+template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
LogicalResult
-IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
- PatternRewriter &rewriter) const {
+PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
+ PowIOpTy op, PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value base = op.getLhs();
@@ -153,16 +157,23 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
return failure();
// Maybe broadcasts scalar value into vector type compatible with `op`.
- auto bcast = [&](Value value) -> Value {
- if (auto vec = op.getType().dyn_cast<VectorType>())
+ auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
+ if (auto vec = op.getType().template dyn_cast<VectorType>())
return rewriter.create<vector::BroadcastOp>(loc, vec, value);
return value;
};
+ Value one;
+ Type opType = getElementTypeOrSelf(op.getType());
+ if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+ one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(opType, 1.0));
+ else
+ one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(opType, 1));
+
+ // Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
- // Replace `ipowi(x, 0)` with `1`.
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
rewriter.replaceOp(op, bcast(one));
return success();
}
@@ -178,25 +189,22 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
return failure();
// Inverse the base for negative exponent, i.e. for
- // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
- if (exponentIsNegative) {
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
- base = rewriter.create<arith::DivSIOp>(loc, bcast(one), base);
- }
+ // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
+ if (exponentIsNegative)
+ base = rewriter.create<DivOpTy>(loc, bcast(one), base);
Value result = base;
// Transform to naive sequence of multiplications:
// * For positive exponent case replace:
- // `ipowi(x, positive_exponent)`
+ // `[fi]powi(x, positive_exponent)`
// with:
// x * x * x * ...
// * For negative exponent case replace:
- // `ipowi(x, negative_exponent)`
+ // `[fi]powi(x, negative_exponent)`
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
for (unsigned i = 1; i < exponentValue; ++i)
- result = rewriter.create<arith::MulIOp>(loc, result, base);
+ result = rewriter.create<MulOpTy>(loc, result, base);
rewriter.replaceOp(op, result);
return success();
@@ -206,6 +214,9 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
void mlir::populateMathAlgebraicSimplificationPatterns(
RewritePatternSet &patterns) {
- patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
- patterns.getContext());
+ patterns
+ .add<PowFStrengthReduction,
+ PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
+ PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 106f7fd0c4d0..806779ad9198 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -163,3 +163,93 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
}
+
+// CHECK-LABEL: @fpowi_zero_exp(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>) {
+func.func @fpowi_zero_exp(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+ // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+ // CHECK: return %[[CST_S]], %[[CST_V]]
+ %c = arith.constant 0 : i32
+ %v = arith.constant dense <0> : vector<4xi32>
+ %0 = math.fpowi %arg0, %c : f32, i32
+ %1 = math.fpowi %arg1, %v : vector<4xf32>, vector<4xi32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_one(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_one(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+ // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+ // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+ // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
+ // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
+ // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
+ %c1 = arith.constant 1 : i32
+ %v1 = arith.constant dense <1> : vector<4xi32>
+ %0 = math.fpowi %arg0, %c1 : f32, i32
+ %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32>
+ %cm1 = arith.constant -1 : i32
+ %vm1 = arith.constant dense <-1> : vector<4xi32>
+ %2 = math.fpowi %arg0, %cm1 : f32, i32
+ %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32>
+ return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+ // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+ // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+ // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
+ // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
+ // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+ %c1 = arith.constant 2 : i32
+ %v1 = arith.constant dense <2> : vector<4xi32>
+ %0 = math.fpowi %arg0, %c1 : f32, i32
+ %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32>
+ %cm1 = arith.constant -2 : i32
+ %vm1 = arith.constant dense <-2> : vector<4xi32>
+ %2 = math.fpowi %arg0, %cm1 : f32, i32
+ %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32>
+ return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+ // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+ // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+ // CHECK: %[[SMUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
+ // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
+ // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
+ // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
+ // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
+ // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+ %c1 = arith.constant 3 : i32
+ %v1 = arith.constant dense <3> : vector<4xi32>
+ %0 = math.fpowi %arg0, %c1 : f32, i32
+ %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32>
+ %cm1 = arith.constant -3 : i32
+ %vm1 = arith.constant dense <-3> : vector<4xi32>
+ %2 = math.fpowi %arg0, %cm1 : f32, i32
+ %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32>
+ return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32>
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index a2b959a783f2..c25d09734e22 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -158,6 +158,20 @@ func.func @powf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
+// CHECK-LABEL: func @fpowi(
+// CHECK-SAME: %[[SB:.*]]: f32, %[[SP:.*]]: i32,
+// CHECK-SAME: %[[VB:.*]]: vector<4xf64>, %[[VP:.*]]: vector<4xi16>,
+// CHECK-SAME: %[[TB:.*]]: tensor<4x3x?xf16>, %[[TP:.*]]: tensor<4x3x?xi64>) {
+func.func @fpowi(%b: f32, %p: i32, %vb: vector<4xf64>, %vp: vector<4xi16>, %tb: tensor<4x3x?xf16>, %tp: tensor<4x3x?xi64>) {
+// CHECK: {{.*}} = math.fpowi %[[SB]], %[[SP]] : f32, i32
+ %0 = math.fpowi %b, %p : f32, i32
+// CHECK: {{.*}} = math.fpowi %[[VB]], %[[VP]] : vector<4xf64>, vector<4xi16>
+ %1 = math.fpowi %vb, %vp : vector<4xf64>, vector<4xi16>
+// CHECK: {{.*}} = math.fpowi %[[TB]], %[[TP]] : tensor<4x3x?xf16>, tensor<4x3x?xi64>
+ %2 = math.fpowi %tb, %tp : tensor<4x3x?xf16>, tensor<4x3x?xi64>
+ return
+}
+
// CHECK-LABEL: func @rsqrt(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @rsqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
More information about the Mlir-commits
mailing list