[Mlir-commits] [mlir] f9d988f - [mlir][math] Added basic support for FPowI operation.

Slava Zakharin llvmlistbot at llvm.org
Tue Aug 16 09:24:43 PDT 2022


Author: Slava Zakharin
Date: 2022-08-16T09:24:01-07:00
New Revision: f9d988f1acde945c5cbdabe468c147492108dc81

URL: https://github.com/llvm/llvm-project/commit/f9d988f1acde945c5cbdabe468c147492108dc81
DIFF: https://github.com/llvm/llvm-project/commit/f9d988f1acde945c5cbdabe468c147492108dc81.diff

LOG: [mlir][math] Added basic support for FPowI operation.

The operation computes pow(b, p), where 'b' is floating point
and 'p' is a signed integer. The result type matches the type of 'b'.
The operands must have the same shape.

Differential Revision: https://reviews.llvm.org/D129811

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
    mlir/test/Dialect/Math/algebraic-simplification.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 98ed2eed39c6..706e2adb5852 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -770,4 +770,51 @@ def Math_RoundOp : Math_FloatUnaryOp<"round"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FPowIOp
+//===----------------------------------------------------------------------===//
+
+def Math_FPowIOp : Math_Op<"fpowi",
+    [SameOperandsAndResultShape, AllTypesMatch<["lhs", "result"]>]> {
+  let summary = "floating point raised to the signed integer power";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.fpowi` ssa-use `,` ssa-use `:` type `,` type
+    ```
+
+    The `fpowi` operation takes a `base` operand (`lhs`) of floating point
+    type (scalar, tensor or vector) and a `power` operand (`rhs`) of signless
+    integer type with the same shape, and returns one result of the same type
+    as `base`. The result is `base` raised to the power of `power`.
+    The operation is elementwise for non-scalars, e.g.:
+
+    ```mlir
+    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
+    ```
+
+    The result is a vector of:
+
+    ```
+    [<math.fpowi %base[0], %power[0]>, <math.fpowi %base[1], %power[1]>]
+    ```
+
+    Example:
+
+    ```mlir
+    // Scalar exponentiation.
+    %a = math.fpowi %base, %power : f64, i32
+    ```
+  }];
+
+  let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs FloatLike:$result);
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+
+  // TODO: add a constant folder using pow[f] for cases when
+  //       the power argument is exactly representable in the floating
+  //       point type of the base.
+}
+
 #endif // MATH_OPS

diff  --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index b967b8699b45..2c2821531e46 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -113,27 +113,31 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 }
 
 //----------------------------------------------------------------------------//
-// IPowIOp strength reduction.
+// FPowIOp/IPowIOp strength reduction.
 //----------------------------------------------------------------------------//
 
 namespace {
-struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
+template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
+struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
+  // PowIOpTy: [fi]powi op to reduce; DivOpTy/MulOpTy: ops used to expand it.
   unsigned exponentThreshold;
 
 public:
-  IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
-                         PatternBenefit benefit = 1,
-                         ArrayRef<StringRef> generatedNames = {})
-      : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
+  PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+                        PatternBenefit benefit = 1,
+                        ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
         exponentThreshold(exponentThreshold) {}
-  LogicalResult matchAndRewrite(math::IPowIOp op,
+  // Expands `op` into MulOpTy ops (DivOpTy first for negative exponents).
+  LogicalResult matchAndRewrite(PowIOpTy op,
                                 PatternRewriter &rewriter) const final;
 };
 } // namespace
 
+template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
 LogicalResult
-IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
-                                        PatternRewriter &rewriter) const {
+PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
+    PowIOpTy op, PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value base = op.getLhs();
 
@@ -153,16 +157,23 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
     return failure();
 
   // Maybe broadcasts scalar value into vector type compatible with `op`.
-  auto bcast = [&](Value value) -> Value {
-    if (auto vec = op.getType().dyn_cast<VectorType>())
+  auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
+    if (auto vec = op.getType().template dyn_cast<VectorType>())
       return rewriter.create<vector::BroadcastOp>(loc, vec, value);
     return value;
   };
+  // NOTE(review): `one` is created eagerly, before the failure path below.
+  Value one;
+  Type opType = getElementTypeOrSelf(op.getType());
+  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+    one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getFloatAttr(opType, 1.0));
+  else
+    one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(opType, 1));
+  // TODO: build `one` lazily; a pattern must not modify IR and then fail.
+  // Replace `[fi]powi(x, 0)` with `1` (broadcast for vector results).
   if (exponentValue == 0) {
-    // Replace `ipowi(x, 0)` with `1`.
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
     rewriter.replaceOp(op, bcast(one));
     return success();
   }
@@ -178,25 +189,22 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
     return failure();
 
   // Inverse the base for negative exponent, i.e. for
-  // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
-  if (exponentIsNegative) {
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
-    base = rewriter.create<arith::DivSIOp>(loc, bcast(one), base);
-  }
+  // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
+  if (exponentIsNegative)
+    base = rewriter.create<DivOpTy>(loc, bcast(one), base);
 
   Value result = base;
   // Transform to naive sequence of multiplications:
   //   * For positive exponent case replace:
-  //       `ipowi(x, positive_exponent)`
+  //       `[fi]powi(x, positive_exponent)`
   //     with:
   //       x * x * x * ...
   //   * For negative exponent case replace:
-  //       `ipowi(x, negative_exponent)`
+  //       `[fi]powi(x, negative_exponent)`
   //     with:
   //       (1 / x) * (1 / x) * (1 / x) * ...
   for (unsigned i = 1; i < exponentValue; ++i)
-    result = rewriter.create<arith::MulIOp>(loc, result, base);
+    result = rewriter.create<MulOpTy>(loc, result, base);
 
   rewriter.replaceOp(op, result);
   return success();
@@ -206,6 +214,9 @@ IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
-      patterns.getContext());
+  patterns // powf, plus powi reductions for integer (ipowi) and float (fpowi).
+      .add<PowFStrengthReduction,
+           PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
+           PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
+          patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 106f7fd0c4d0..806779ad9198 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -163,3 +163,93 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
   %3 = math.ipowi %arg1, %vm1 : vector<4xi32>
   return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
 }
+
+// CHECK-LABEL: @fpowi_zero_exp(
+// CHECK-SAME: %[[BASE_S:.+]]: f32
+// CHECK-SAME: %[[BASE_V:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>) {
+func.func @fpowi_zero_exp(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[ONE_S:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[ONE_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: return %[[ONE_S]], %[[ONE_V]]
+  %zero = arith.constant 0 : i32
+  %zeros = arith.constant dense<0> : vector<4xi32>
+  %pow_s = math.fpowi %arg0, %zero : f32, i32
+  %pow_v = math.fpowi %arg1, %zeros : vector<4xf32>, vector<4xi32>
+  return %pow_s, %pow_v : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_one(
+// CHECK-SAME: %[[BASE_S:.+]]: f32
+// CHECK-SAME: %[[BASE_V:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_one(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+  // CHECK: %[[ONE_S:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[ONE_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: %[[INV_S:.*]] = arith.divf %[[ONE_S]], %[[BASE_S]]
+  // CHECK: %[[INV_V:.*]] = arith.divf %[[ONE_V]], %[[BASE_V]]
+  // CHECK: return %[[BASE_S]], %[[BASE_V]], %[[INV_S]], %[[INV_V]]
+  %one = arith.constant 1 : i32
+  %ones = arith.constant dense<1> : vector<4xi32>
+  %pos_s = math.fpowi %arg0, %one : f32, i32
+  %pos_v = math.fpowi %arg1, %ones : vector<4xf32>, vector<4xi32>
+  %neg_one = arith.constant -1 : i32
+  %neg_ones = arith.constant dense<-1> : vector<4xi32>
+  %neg_s = math.fpowi %arg0, %neg_one : f32, i32
+  %neg_v = math.fpowi %arg1, %neg_ones : vector<4xf32>, vector<4xi32>
+  return %pos_s, %pos_v, %neg_s, %neg_v : f32, vector<4xf32>, f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  %c2 = arith.constant 2 : i32
+  %v2 = arith.constant dense <2> : vector<4xi32>
+  %0 = math.fpowi %arg0, %c2 : f32, i32
+  %1 = math.fpowi %arg1, %v2 : vector<4xf32>, vector<4xi32>
+  %cm2 = arith.constant -2 : i32
+  %vm2 = arith.constant dense <-2> : vector<4xi32>
+  %2 = math.fpowi %arg0, %cm2 : f32, i32
+  %3 = math.fpowi %arg1, %vm2 : vector<4xf32>, vector<4xi32>
+  return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @fpowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: %[[SMUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
+  // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  %c3 = arith.constant 3 : i32
+  %v3 = arith.constant dense <3> : vector<4xi32>
+  %0 = math.fpowi %arg0, %c3 : f32, i32
+  %1 = math.fpowi %arg1, %v3 : vector<4xf32>, vector<4xi32>
+  %cm3 = arith.constant -3 : i32
+  %vm3 = arith.constant dense <-3> : vector<4xi32>
+  %2 = math.fpowi %arg0, %cm3 : f32, i32
+  %3 = math.fpowi %arg1, %vm3 : vector<4xf32>, vector<4xi32>
+  return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32>
+}

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index a2b959a783f2..c25d09734e22 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -158,6 +158,20 @@ func.func @powf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @fpowi(
+// CHECK-SAME: %[[SB:.*]]: f32, %[[SP:.*]]: i32,
+// CHECK-SAME: %[[VB:.*]]: vector<4xf64>, %[[VP:.*]]: vector<4xi16>,
+// CHECK-SAME: %[[TB:.*]]: tensor<4x3x?xf16>, %[[TP:.*]]: tensor<4x3x?xi64>) {
+func.func @fpowi(%b: f32, %p: i32, %vb: vector<4xf64>, %vp: vector<4xi16>, %tb: tensor<4x3x?xf16>, %tp: tensor<4x3x?xi64>) {
+  // CHECK: %{{.*}} = math.fpowi %[[SB]], %[[SP]] : f32, i32
+  %0 = math.fpowi %b, %p : f32, i32
+  // CHECK: %{{.*}} = math.fpowi %[[VB]], %[[VP]] : vector<4xf64>, vector<4xi16>
+  %1 = math.fpowi %vb, %vp : vector<4xf64>, vector<4xi16>
+  // CHECK: %{{.*}} = math.fpowi %[[TB]], %[[TP]] : tensor<4x3x?xf16>, tensor<4x3x?xi64>
+  %2 = math.fpowi %tb, %tp : tensor<4x3x?xf16>, tensor<4x3x?xi64>
+  return
+}
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME:              %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func.func @rsqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {


        


More information about the Mlir-commits mailing list