[Mlir-commits] [mlir] [mlir][[spirv] Add support for math.log2 and math.log10 to GLSL/OpenCL SPIRV Backends (PR #104608)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 16 18:57:51 PDT 2024


https://github.com/meehatpa updated https://github.com/llvm/llvm-project/pull/104608

>From 232077e04c7e3c18e66718845b1acfbc89c3441c Mon Sep 17 00:00:00 2001
From: meehatpa <gune30 at gmail.com>
Date: Fri, 16 Aug 2024 15:30:21 +0000
Subject: [PATCH] [mlir][spirv] Add support for math.log2 and math.log10 to
 GLSL/OpenCL SPIR-V Backends

SPIR-V has no log2 or log10 instructions, so lower them as a decomposition:
compute the natural logarithm with spirv.CL.log/spirv.GL.Log and multiply it
by the constant 1/log(2) or 1/log(10).
---
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    | 63 +++++++++++++++++++
 .../MathToSPIRV/math-to-gl-spirv.mlir         | 42 +++++++++----
 .../MathToSPIRV/math-to-opencl-spirv.mlir     | 46 +++++++++-----
 3 files changed, 123 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 5b3c2fb15e7026..b123a22a805e33 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -291,6 +291,65 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
   }
 };
 
+/// Converts math.log2 and math.log10 to SPIR-V ops.
+///
+/// SPIR-V has no direct log2/log10 instructions, so both are lowered via the
+/// natural logarithm provided by `SpirvLogOp` (GL.Log or CL.log):
+///   log2(x) = log(x) * 1/log(2)
+///   log10(x) = log(x) * 1/log(10)
+template <typename MathLogOp, typename SpirvLogOp>
+struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
+  using OpConversionPattern<MathLogOp>::OpConversionPattern;
+  using typename OpConversionPattern<MathLogOp>::OpAdaptor;
+
+  static constexpr double log2Reciprocal =
+      1.442695040888963407359924681001892137426645954152985934135449407;
+  static constexpr double log10Reciprocal =
+      0.4342944819032518276511289189166050822943970058036665661144537832;
+
+  LogicalResult
+  matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+        failed(res))
+      return res;
+
+    Location loc = operation.getLoc();
+    Type type = this->getTypeConverter()->convertType(operation.getType());
+    if (!type)
+      return rewriter.notifyMatchFailure(operation, "type conversion failed");
+
+    // Materializes `value` as a scalar or splat-vector constant of `type`.
+    auto getConstantValue = [&](double value) {
+      if (auto floatType = dyn_cast<FloatType>(type)) {
+        return rewriter.create<spirv::ConstantOp>(
+            loc, type, rewriter.getFloatAttr(floatType, value));
+      }
+      if (auto vectorType = dyn_cast<VectorType>(type)) {
+        Type elemType = vectorType.getElementType();
+        if (isa<FloatType>(elemType)) {
+          return rewriter.create<spirv::ConstantOp>(
+              loc, type,
+              DenseFPElementsAttr::get(
+                  vectorType, FloatAttr::get(elemType, value).getValue()));
+        }
+      }
+      // NOTE(review): assumed unreachable for float-like operands; confirm.
+      llvm_unreachable("unimplemented types for log2/log10");
+    };
+
+    // Scale log(x) by a precomputed reciprocal: one FMul, no FDiv needed.
+    Value constantValue = getConstantValue(
+        std::is_same_v<MathLogOp, math::Log2Op> ? log2Reciprocal
+                                                : log10Reciprocal);
+    Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
+    rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
+                                               constantValue);
+    return success();
+  }
+};
+
 /// Converts math.powf to SPIRV-Ops.
 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -411,6 +470,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   // GLSL patterns
   patterns
       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
+           Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
+           Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
            CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
            CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
@@ -430,6 +491,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
   // OpenCL patterns
   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
+               Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
+               Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
                CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
                CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
                CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index a9397667393429..ecbd59e54971ef 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -22,22 +22,30 @@ func.func @float32_unary_scalar(%arg0: f32) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
   %5 = math.log1p %arg0 : f32
+  // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
+  // CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
+  // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+  %6 = math.log2 %arg0 : f32
+  // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
+  // CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
+  // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+  %7 = math.log10 %arg0 : f32
   // CHECK: spirv.GL.RoundEven %{{.*}}: f32
-  %6 = math.roundeven %arg0 : f32
+  %8 = math.roundeven %arg0 : f32
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
-  %7 = math.rsqrt %arg0 : f32
+  %9 = math.rsqrt %arg0 : f32
   // CHECK: spirv.GL.Sqrt %{{.*}}: f32
-  %8 = math.sqrt %arg0 : f32
+  %10 = math.sqrt %arg0 : f32
   // CHECK: spirv.GL.Tanh %{{.*}}: f32
-  %9 = math.tanh %arg0 : f32
+  %11 = math.tanh %arg0 : f32
   // CHECK: spirv.GL.Sin %{{.*}}: f32
-  %10 = math.sin %arg0 : f32
+  %12 = math.sin %arg0 : f32
   // CHECK: spirv.GL.FAbs %{{.*}}: f32
-  %11 = math.absf %arg0 : f32
+  %13 = math.absf %arg0 : f32
   // CHECK: spirv.GL.Ceil %{{.*}}: f32
-  %12 = math.ceil %arg0 : f32
+  %14 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
-  %13 = math.floor %arg0 : f32
+  %15 = math.floor %arg0 : f32
   return
 }
 
@@ -59,16 +67,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
   %5 = math.log1p %arg0 : vector<3xf32>
+  // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
+  // CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
+  // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+  %6 = math.log2 %arg0 : vector<3xf32>
+  // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
+  // CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
+  // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+  %7 = math.log10 %arg0 : vector<3xf32>
   // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
-  %6 = math.roundeven %arg0 : vector<3xf32>
+  %8 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
-  %7 = math.rsqrt %arg0 : vector<3xf32>
+  %9 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
-  %8 = math.sqrt %arg0 : vector<3xf32>
+  %10 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
-  %9 = math.tanh %arg0 : vector<3xf32>
+  %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
-  %10 = math.sin %arg0 : vector<3xf32>
+  %12 = math.sin %arg0 : vector<3xf32>
   return
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index e9ca838354c0de..393a910c1fb1d7 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -20,26 +20,34 @@ func.func @float32_unary_scalar(%arg0: f32) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
   %5 = math.log1p %arg0 : f32
+  // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
+  // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
+  // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+  %6 = math.log2 %arg0 : f32
+  // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
+  // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
+  // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+  %7 = math.log10 %arg0 : f32
   // CHECK: spirv.CL.rint %{{.*}}: f32
-  %6 = math.roundeven %arg0 : f32
+  %8 = math.roundeven %arg0 : f32
   // CHECK: spirv.CL.rsqrt %{{.*}}: f32
-  %7 = math.rsqrt %arg0 : f32
+  %9 = math.rsqrt %arg0 : f32
   // CHECK: spirv.CL.sqrt %{{.*}}: f32
-  %8 = math.sqrt %arg0 : f32
+  %10 = math.sqrt %arg0 : f32
   // CHECK: spirv.CL.tanh %{{.*}}: f32
-  %9 = math.tanh %arg0 : f32
+  %11 = math.tanh %arg0 : f32
   // CHECK: spirv.CL.sin %{{.*}}: f32
-  %10 = math.sin %arg0 : f32
+  %12 = math.sin %arg0 : f32
   // CHECK: spirv.CL.fabs %{{.*}}: f32
-  %11 = math.absf %arg0 : f32
+  %13 = math.absf %arg0 : f32
   // CHECK: spirv.CL.ceil %{{.*}}: f32
-  %12 = math.ceil %arg0 : f32
+  %14 = math.ceil %arg0 : f32
   // CHECK: spirv.CL.floor %{{.*}}: f32
-  %13 = math.floor %arg0 : f32
+  %15 = math.floor %arg0 : f32
   // CHECK: spirv.CL.erf %{{.*}}: f32
-  %14 = math.erf %arg0 : f32
+  %16 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
-  %15 = math.round %arg0 : f32
+  %17 = math.round %arg0 : f32
   return
 }
 
@@ -61,16 +69,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
   %5 = math.log1p %arg0 : vector<3xf32>
+  // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
+  // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
+  // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+  %6 = math.log2 %arg0 : vector<3xf32>
+  // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
+  // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
+  // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+  %7 = math.log10 %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
-  %6 = math.roundeven %arg0 : vector<3xf32>
+  %8 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
-  %7 = math.rsqrt %arg0 : vector<3xf32>
+  %9 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
-  %8 = math.sqrt %arg0 : vector<3xf32>
+  %10 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
-  %9 = math.tanh %arg0 : vector<3xf32>
+  %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
-  %10 = math.sin %arg0 : vector<3xf32>
+  %12 = math.sin %arg0 : vector<3xf32>
   return
 }
 



More information about the Mlir-commits mailing list