[Mlir-commits] [mlir] [mlir][spirv] Add support for math.log2 and math.log10 to GLSL/OpenCL SPIR-V Backends (PR #104608)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 16 19:02:06 PDT 2024


================
@@ -291,6 +291,65 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
   }
 };
 
+/// Converts math.log2 and math.log10 to SPIR-V ops.
+///
+/// SPIR-V does not have direct operations for log2 and log10, so lower them
+/// explicitly in terms of the natural logarithm:
+///   log2(x) = log(x) * 1/log(2)
+///   log10(x) = log(x) * 1/log(10)
+
+// Multipliers for the change-of-base lowering: 1/ln(2) (== log2(e)) and
+// 1/ln(10) (== log10(e)). The literals are unsuffixed, so they are doubles and
+// the digits beyond double precision are rounded away.
+#define LOG2_RECIPROCAL                                                        \
+  1.442695040888963407359924681001892137426645954152985934135449407
+#define LOG10_RECIPROCAL                                                       \
+  0.4342944819032518276511289189166050822943970058036665661144537832
+
+template <typename MathLogOp, typename SpirvLogOp>
+struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
+  using OpConversionPattern<MathLogOp>::OpConversionPattern;
+  using typename OpConversionPattern<MathLogOp>::OpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+        failed(res))
+      return res;
+
+    Location loc = operation.getLoc();
+    Type type = this->getTypeConverter()->convertType(operation.getType());
+    if (!type)
+      return failure();
+
+    auto getConstantValue = [&](double value) {
+      if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+        return rewriter.create<spirv::ConstantOp>(
+            loc, type, rewriter.getFloatAttr(floatType, value));
+      }
+      if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
+        Type elemType = vectorType.getElementType();
+
+        if (llvm::isa<FloatType>(elemType)) {
+          return rewriter.create<spirv::ConstantOp>(
+              loc, type,
+              DenseFPElementsAttr::get(
+                  vectorType, FloatAttr::get(elemType, value).getValue()));
+        }
+      }
+
+      llvm_unreachable("unimplemented types for log2/log10");
+    };
+
+    auto constantValue = getConstantValue(
+        std::is_same<MathLogOp, math::Log2Op>() ? LOG2_RECIPROCAL
+                                                : LOG10_RECIPROCAL);
+    auto log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
----------------
meehatpa wrote:

Done

https://github.com/llvm/llvm-project/pull/104608


More information about the Mlir-commits mailing list