[Mlir-commits] [mlir] 6b53881 - [mlir][math] Add math.cbrt polynomial approximation
Rob Suderman
llvmlistbot at llvm.org
Mon Mar 6 13:31:01 PST 2023
Author: Robert Suderman
Date: 2023-03-06T13:29:49-08:00
New Revision: 6b5388104803262fedc783ad09d4b4fdfcc3646f
URL: https://github.com/llvm/llvm-project/commit/6b5388104803262fedc783ad09d4b4fdfcc3646f
DIFF: https://github.com/llvm/llvm-project/commit/6b5388104803262fedc783ad09d4b4fdfcc3646f.diff
LOG: [mlir][math] Add math.cbrt polynomial approximation
Cbrt can be approximated with some relatively simple polynomial
operators. This includes a lit test validating the implementation
and some run tests that validate numerical correctness.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D145019
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index c0f30283ca8e6..0d170f985fc30 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1212,6 +1212,99 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return success();
}
+//----------------------------------------------------------------------------//
+// Cbrt approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
+ using OpRewritePattern::OpRewritePattern;
+ // Expands an f32 (scalar or vector) math.cbrt into arith/math ops.
+ LogicalResult matchAndRewrite(math::CbrtOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates the cube root of an f32 scalar or vector with the bit-level
+// estimate and two Newton steps described in Hacker's Delight, 2nd Edition.
+LogicalResult
+CbrtApproximation::matchAndRewrite(math::CbrtOp op,
+ PatternRewriter &rewriter) const {
+ auto operand = op.getOperand();
+ if (!getElementTypeOrSelf(operand).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ ArrayRef<int64_t> shape = vectorShape(operand);
+
+ Type floatTy = getElementTypeOrSelf(operand.getType());
+ Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
+
+ // Convert to vector types if necessary.
+ floatTy = broadcast(floatTy, shape);
+ intTy = broadcast(intTy, shape);
+
+ auto bconst = [&](Attribute attr) -> Value {
+ Value value = b.create<arith::ConstantOp>(attr);
+ return broadcast(b, value, shape);
+ };
+
+ // Materialize the integer and floating-point constants used below:
+ Value intTwo = bconst(b.getI32IntegerAttr(2));
+ Value intFour = bconst(b.getI32IntegerAttr(4));
+ Value intEight = bconst(b.getI32IntegerAttr(8));
+ Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
+ Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
+ Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
+ Value fpZero = bconst(b.getF32FloatAttr(0.0f));
+
+ // Approximate ix / 3 (on the bits of |x0|) with shifts and adds:
+ // union {int ix; float x;};
+ // x = x0;
+ // ix = ix/4 + ix/16;
+ Value absValue = b.create<math::AbsFOp>(operand);
+ Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
+ Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
+ Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+ intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
+
+ // ix = ix + ix/16;
+ divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+ intValue = b.create<arith::AddIOp>(intValue, divideBy16);
+
+ // ix = ix + ix/256;
+ Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
+ intValue = b.create<arith::AddIOp>(intValue, divideBy256);
+
+ // ix = 0x2a5137a0 + ix;
+ intValue = b.create<arith::AddIOp>(intValue, intMagic);
+
+ // Perform the first Newton step:
+ // x = 0.33333333f*(2.0f*x + x0/(x*x));
+ Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
+ Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
+ Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+ Value divSquared = b.create<arith::DivFOp>(absValue, squared);
+ floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+ floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+ // Second Newton step: x = 0.33333333f*(2.0f*x + x0/(x*x));
+ squared = b.create<arith::MulFOp>(floatValue, floatValue);
+ mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+ divSquared = b.create<arith::DivFOp>(absValue, squared);
+ floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+ floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+ // Map zero inputs to exactly zero, then restore the operand's sign.
+ Value isZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
+ floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
+ floatValue = b.create<math::CopySignOp>(floatValue, operand);
+
+ rewriter.replaceOp(op, floatValue);
+ return success();
+}
+
//----------------------------------------------------------------------------//
// Rsqrt approximation.
//----------------------------------------------------------------------------//
@@ -1291,7 +1384,7 @@ void mlir::populateMathPolynomialApproximationPatterns(
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
LogApproximation, Log2Approximation, Log1pApproximation,
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- ReuseF32Expansion<math::Atan2Op>,
+ CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 33ac11b29b9c6..4b490e4ea990c 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -593,3 +593,53 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
%0 = math.atan2 %arg0, %arg1 : f16
return %0 : f16
}
+
+// CHECK-LABEL: @cbrt_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+
+// CHECK: %[[TWO_INT:.+]] = arith.constant dense<2>
+// CHECK: %[[FOUR_INT:.+]] = arith.constant dense<4>
+// CHECK: %[[EIGHT_INT:.+]] = arith.constant dense<8>
+// CHECK: %[[MAGIC:.+]] = arith.constant dense<709965728>
+// CHECK: %[[THIRD_FP:.+]] = arith.constant dense<0.333333343> : vector<4xf32>
+// CHECK: %[[TWO_FP:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[ZERO_FP:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// CHECK: %[[ABS:.+]] = math.absf %[[ARG0]] : vector<4xf32>
+
+// Perform the initial approximation:
+// CHECK: %[[CAST:.+]] = arith.bitcast %[[ABS]] : vector<4xf32> to vector<4xi32>
+// CHECK: %[[SH_TWO:.+]] = arith.shrsi %[[CAST]], %[[TWO_INT]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[CAST]], %[[FOUR_INT]]
+// CHECK: %[[APPROX0:.+]] = arith.addi %[[SH_TWO]], %[[SH_FOUR]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[APPROX0]], %[[FOUR_INT]]
+// CHECK: %[[APPROX1:.+]] = arith.addi %[[APPROX0]], %[[SH_FOUR]]
+// CHECK: %[[SH_EIGHT:.+]] = arith.shrsi %[[APPROX1]], %[[EIGHT_INT]]
+// CHECK: %[[APPROX2:.+]] = arith.addi %[[APPROX1]], %[[SH_EIGHT]]
+// CHECK: %[[FIX:.+]] = arith.addi %[[APPROX2]], %[[MAGIC]]
+// CHECK: %[[BCAST:.+]] = arith.bitcast %[[FIX]]
+
+// First Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[BCAST]], %[[BCAST]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[BCAST]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX3:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Second Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[APPROX3]], %[[APPROX3]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[APPROX3]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX4:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Check for zero special case and copy the sign:
+// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[ABS]], %[[ZERO_FP]]
+// CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[ZERO_FP]], %[[APPROX4]]
+// CHECK: %[[SIGN:.+]] = math.copysign %[[SEL]], %[[ARG0]]
+// CHECK: return %[[SIGN]]
+
+func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+ %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
+ func.return %0 : vector<4xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index dbd816639ede3..665d3280c0c46 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -568,6 +568,48 @@ func.func @atan2() {
}
+// -------------------------------------------------------------------------- //
+// Cbrt.
+// -------------------------------------------------------------------------- //
+
+func.func @cbrt_f32(%a : f32) {
+ %r = math.cbrt %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @cbrt() {
+ // CHECK: 1
+ %a = arith.constant 1.0 : f32
+ call @cbrt_f32(%a) : (f32) -> ()
+
+ // CHECK: -1
+ %b = arith.constant -1.0 : f32
+ call @cbrt_f32(%b) : (f32) -> ()
+
+ // CHECK: 0
+ %c = arith.constant 0.0 : f32
+ call @cbrt_f32(%c) : (f32) -> ()
+
+ // CHECK: -0
+ %d = arith.constant -0.0 : f32
+ call @cbrt_f32(%d) : (f32) -> ()
+
+ // CHECK: 10
+ %e = arith.constant 1000.0 : f32
+ call @cbrt_f32(%e) : (f32) -> ()
+
+ // CHECK: -10
+ %f = arith.constant -1000.0 : f32
+ call @cbrt_f32(%f) : (f32) -> ()
+
+ // CHECK: 2.57128
+ %g = arith.constant 17.0 : f32
+ call @cbrt_f32(%g) : (f32) -> ()
+
+ return
+}
+
func.func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
@@ -580,5 +622,8 @@ func.func @main() {
call @cos(): () -> ()
call @atan() : () -> ()
call @atan2() : () -> ()
+ call @cbrt() : () -> ()
return
}
+
+
More information about the Mlir-commits
mailing list