[Mlir-commits] [mlir] [mlir][math] Expand powfI operation for constant power operand. (PR #87081)
Prashant Kumar
llvmlistbot at llvm.org
Sat Mar 30 06:26:56 PDT 2024
https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/87081
>From 58c30b405108d688cfac55d937b0a3c7c62b0b8d Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Thu, 28 Mar 2024 23:43:00 -0400
Subject: [PATCH] [mlir][math] Expand powfI operation for constant power
operand.
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 +
.../Math/Transforms/ExpandPatterns.cpp | 68 +++++++++++++++
mlir/test/Dialect/Math/expand-math.mlir | 85 +++++++++++++++++++
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 1 +
4 files changed, 155 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 11b2c7a7afa2f7..e2c513047c77a5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
+void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e1ab9c905447b7..1d79ac1422bae5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -202,6 +202,70 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
rewriter.replaceOp(op, ret);
return success();
}
+
+// Convert `math.fpowi` to a series of `arith.mulf` operations.
+// If the power is negative, we divide one by the result.
+static LogicalResult convertFPowICstOp(math::FPowIOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value base = op.getOperand(0);
+ Value power = op.getOperand(1);
+ Type baseType = base.getType();
+ Value tempBase = op.getOperand(0);
+
+ Attribute cstAttr;
+ if (!matchPattern(power, m_Constant(&cstAttr)))
+ return failure();
+
+ int64_t powerInt;
+
+ // Check for Splat or Integer Attrs.
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr)) {
+ powerInt = splatAttr.getSplatValue<int64_t>();
+ } else if (auto iAttr = dyn_cast<IntegerAttr>(cstAttr)) {
+ powerInt = iAttr.getInt();
+ } else {
+ return failure();
+ }
+
+ bool isNegative = powerInt < 0;
+ int64_t absPower = std::abs(powerInt);
+ Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
+ Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
+
+ Value zero = createFloatConst(op->getLoc(), baseType, 0.00, rewriter);
+ Value negZero = createFloatConst(op->getLoc(), baseType, -0.00, rewriter);
+ Value posInfinity =
+ createFloatConst(op->getLoc(), baseType,
+ std::numeric_limits<double_t>::infinity(), rewriter);
+ Value negInfinity =
+ createFloatConst(op->getLoc(), baseType,
+ -std::numeric_limits<double_t>::infinity(), rewriter);
+
+ while (absPower > 0) {
+ if (absPower & 1)
+ res = b.create<arith::MulFOp>(baseType, tempBase, res);
+ absPower >>= 1;
+ tempBase = b.create<arith::MulFOp>(baseType, tempBase, tempBase);
+ }
+
+ // Take care of UB in case of negative power.
+ if (isNegative) {
+ res = b.create<arith::DivFOp>(baseType, one, res);
+ Value zeroEqCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, zero);
+ Value negZeroEqCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, negZero);
+ res =
+ b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
+ res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
+ res);
+ }
+
+ rewriter.replaceOp(op, res);
+ return success();
+}
+
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +581,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertPowfOp);
}
+void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
+ patterns.add(convertFPowICstOp);
+}
+
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6326d3a71874b4..1e6ccc51c05a5d 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -511,3 +511,88 @@ func.func @roundeven16(%arg: f16) -> f16 {
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
// CHECK: return %[[COPYSIGN]] : f16
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
+func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<-3> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
+// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
+// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[UB2]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_neg_even_power
+func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<-4> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
+// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
+// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[UB2]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
+func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<5> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
+// CHECK: return %[[PW5]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_pos_even_power
+func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<4> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: return %[[PW4]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_even_scalar
+func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
+ %pow = arith.constant 2 : i64
+ %2 = math.fpowi %0, %pow : f32, i64
+ return %2 : f32
+}
+// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK: return %[[SQ]] : f32
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 7ce8b5a7cfe9b3..97600ad1ebe7a3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
populateExpandPowFPattern(patterns);
+ populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
More information about the Mlir-commits
mailing list