[Mlir-commits] [mlir] [mlir][math] powi with negative exponent should invert at the end (PR #135735)

Asher Mancinelli llvmlistbot at llvm.org
Mon Apr 14 19:46:17 PDT 2025


https://github.com/ashermancinelli created https://github.com/llvm/llvm-project/pull/135735

Previously, the strength reduction of an FPowI operation with a negative exponent inverted the base *before* performing the sequence of multiplications, which led to discrepancies between the result of LLVM's pow intrinsic folding and the result produced by the math dialect.

See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c

>From 941668627964ff132dfdd915dff8ea8c5b6888d1 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Mon, 14 Apr 2025 16:40:29 -0700
Subject: [PATCH] [mlir][math] powi with negative exponent should invert at the
 end

Previously, an FPowI operation would invert the base *before*
performing a sequence of multiplications, but this led to
discrepancies between LLVM pow intrinsic folding and that
coming from the math dialect.

See compiler-rt's version, which does the inverse at the end of the
calculation: compiler-rt/lib/builtins/powidf2.c
---
 .../Transforms/AlgebraicSimplification.cpp    | 10 ++--
 .../Math/algebraic-simplification.mlir        | 48 +++++++++----------
 2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index dcace489673f0..13e2a4b5541b2 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -197,11 +197,6 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   if (exponentValue > exponentThreshold)
     return failure();
 
-  // Inverse the base for negative exponent, i.e. for
-  // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
-  if (exponentIsNegative)
-    base = rewriter.create<DivOpTy>(loc, bcast(one), base);
-
   Value result = base;
   // Transform to naive sequence of multiplications:
   //   * For positive exponent case replace:
@@ -215,6 +210,11 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   for (unsigned i = 1; i < exponentValue; ++i)
     result = rewriter.create<MulOpTy>(loc, result, base);
 
+  // Invert the result for negative exponent, i.e. for
+  // `[fi]powi(x, negative_exponent)` set `result` to `1 / result`.
+  if (exponentIsNegative)
+    result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+
   rewriter.replaceOp(op, result);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index a97ecc52a17e9..e0e2b9853a2a1 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -135,11 +135,11 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32
   // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
   // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
   // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 2 : i32
   %v1 = arith.constant dense <2> : vector<4xi32>
   %0 = math.ipowi %arg0, %c1 : i32
@@ -162,13 +162,13 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
   // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
   // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
   // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 3 : i32
   %v1 = arith.constant dense <3> : vector<4xi32>
   %0 = math.ipowi %arg0, %c1 : i32
@@ -225,11 +225,11 @@ func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32
   // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
   // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
   // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  // CHECK: %[[SMUL:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL]]
+  // CHECK: %[[VMUL:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 2 : i32
   %v1 = arith.constant dense <2> : vector<4xi32>
   %0 = math.fpowi %arg0, %c1 : f32, i32
@@ -252,13 +252,13 @@ func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf
   // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
   // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
   // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  // CHECK: %[[SMUL1:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL2]]
+  // CHECK: %[[VMUL1:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL2]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 3 : i32
   %v1 = arith.constant dense <3> : vector<4xi32>
   %0 = math.fpowi %arg0, %c1 : f32, i32



More information about the Mlir-commits mailing list