[Mlir-commits] [mlir] [math] lower rsqrt to sqrt + fdiv (PR #91344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 08:25:03 PDT 2024


================
@@ -615,6 +615,24 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
   return success();
 }
 
+// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
+static LogicalResult convertRsqrtOp(math::RsqrtOp op,
+                                    PatternRewriter &rewriter) {
+
+  auto operand = op.getOperand();
+  auto operandTy = operand.getType();
+  auto eTy = getElementTypeOrSelf(operandTy);
+  if (!isa<FloatType>(eTy))
+    return failure();
+
+  Location loc = op->getLoc();
+  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
+  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
----------------
srcarroll wrote:

nit: prefer the TableGen-generated getter functions for operands and results; otherwise you need to check that `getOperand(0)` can be indexed at all. Here you can use `op.getOperand()` instead (see https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Math/IR/MathOps.td#L44).
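
To illustrate the point about indexed versus named accessors, here is a minimal standalone sketch. It does not use MLIR: `GenericOp` and `UnaryOpView` are hypothetical stand-ins invented for this example, and the exception-based bounds check is an assumption made so the sketch is self-contained, not a description of MLIR's own behavior.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a generic operation whose operands are only
// reachable by position. Not an MLIR class.
struct GenericOp {
  std::vector<std::string> operands;

  // Indexed access: every caller has to know that `idx` is valid.
  const std::string &getOperand(std::size_t idx) const {
    if (idx >= operands.size())
      throw std::out_of_range("operand index out of range");
    return operands[idx];
  }
};

// Hypothetical stand-in for a typed view over an op with exactly one operand.
// The operand count is checked once, when the view is built.
struct UnaryOpView {
  const GenericOp *op;

  explicit UnaryOpView(const GenericOp &g) : op(&g) {
    if (g.operands.size() != 1)
      throw std::invalid_argument("expected exactly one operand");
  }

  // Named access: no index to get wrong at the call site.
  const std::string &getOperand() const { return op->operands.front(); }
};
```

With a view like this, a call site reads `view.getOperand()` and cannot pass a wrong index, which is the same readability and safety argument made for `op.getOperand()` over `op->getOperand(0)` in the patch.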

https://github.com/llvm/llvm-project/pull/91344

