[Mlir-commits] [mlir] [mlir][spirv] Add support for math.log2 and math.log10 to GLSL/OpenCL SPIRV Backends (PR #104608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 16 18:59:53 PDT 2024
https://github.com/meehatpa updated https://github.com/llvm/llvm-project/pull/104608
>From 417622bd1429a7ffd2aea6c2887080a8007f6d9e Mon Sep 17 00:00:00 2001
From: meehatpa <gune30 at gmail.com>
Date: Fri, 16 Aug 2024 15:30:21 +0000
Subject: [PATCH] [mlir][spirv] Add support for math.log2 and math.log10 to
GLSL/OpenCL SPIRV Backends
As log2 and log10 are not available in spirv, lower them as a decomposition:
spirv.CL.log/spirv.GL.Log multiplied by the constant 1/ln(2) or 1/ln(10).
---
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 63 +++++++++++++++++++
.../MathToSPIRV/math-to-gl-spirv.mlir | 42 +++++++++----
.../MathToSPIRV/math-to-opencl-spirv.mlir | 46 +++++++++-----
3 files changed, 123 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 5b3c2fb15e7026..52ff138bedf65b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -291,6 +291,65 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
}
};
+/// Converts math.log2 and math.log10 to SPIR-V ops.
+///
+/// SPIR-V does not have direct operations for log2 and log10. Explicitly
+/// lower these operations using the natural logarithm (log):
+/// log2(x) = log(x) * 1/log(2)
+/// log10(x) = log(x) * 1/log(10)
+
+template <typename MathLogOp, typename SpirvLogOp>
+struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
+ using OpConversionPattern<MathLogOp>::OpConversionPattern;
+ using typename OpConversionPattern<MathLogOp>::OpAdaptor;
+
+ static constexpr double log2Reciprocal =
+ 1.442695040888963407359924681001892137426645954152985934135449407;
+ static constexpr double log10Reciprocal =
+ 0.4342944819032518276511289189166050822943970058036665661144537832;
+
+ LogicalResult
+ matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 1);
+ if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+ failed(res))
+ return res;
+
+ Location loc = operation.getLoc();
+ Type type = this->getTypeConverter()->convertType(operation.getType());
+ if (!type)
+ return rewriter.notifyMatchFailure(operation, "type conversion failed");
+
+ auto getConstantValue = [&](double value) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
+ return rewriter.create<spirv::ConstantOp>(
+ loc, type, rewriter.getFloatAttr(floatType, value));
+ }
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
+ Type elemType = vectorType.getElementType();
+
+ if (isa<FloatType>(elemType)) {
+ return rewriter.create<spirv::ConstantOp>(
+ loc, type,
+ DenseFPElementsAttr::get(
+ vectorType, FloatAttr::get(elemType, value).getValue()));
+ }
+ }
+
+ llvm_unreachable("unimplemented types for log2/log10");
+ };
+
+ Value constantValue = getConstantValue(
+ std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
+ : log10Reciprocal);
+ Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
+ rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
+ constantValue);
+ return success();
+ }
+};
+
/// Converts math.powf to SPIRV-Ops.
struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
using OpConversionPattern::OpConversionPattern;
@@ -411,6 +470,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// GLSL patterns
patterns
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
+ Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
+ Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
@@ -430,6 +491,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// OpenCL patterns
patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
+ Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
+ Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index a9397667393429..ecbd59e54971ef 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -22,22 +22,30 @@ func.func @float32_unary_scalar(%arg0: f32) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.GL.Log %[[ADDONE]]
%5 = math.log1p %arg0 : f32
+ // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
+ // CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
+ // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ %6 = math.log2 %arg0 : f32
+ // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
+ // CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
+ // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ %7 = math.log10 %arg0 : f32
// CHECK: spirv.GL.RoundEven %{{.*}}: f32
- %6 = math.roundeven %arg0 : f32
+ %8 = math.roundeven %arg0 : f32
// CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
- %7 = math.rsqrt %arg0 : f32
+ %9 = math.rsqrt %arg0 : f32
// CHECK: spirv.GL.Sqrt %{{.*}}: f32
- %8 = math.sqrt %arg0 : f32
+ %10 = math.sqrt %arg0 : f32
// CHECK: spirv.GL.Tanh %{{.*}}: f32
- %9 = math.tanh %arg0 : f32
+ %11 = math.tanh %arg0 : f32
// CHECK: spirv.GL.Sin %{{.*}}: f32
- %10 = math.sin %arg0 : f32
+ %12 = math.sin %arg0 : f32
// CHECK: spirv.GL.FAbs %{{.*}}: f32
- %11 = math.absf %arg0 : f32
+ %13 = math.absf %arg0 : f32
// CHECK: spirv.GL.Ceil %{{.*}}: f32
- %12 = math.ceil %arg0 : f32
+ %14 = math.ceil %arg0 : f32
// CHECK: spirv.GL.Floor %{{.*}}: f32
- %13 = math.floor %arg0 : f32
+ %15 = math.floor %arg0 : f32
return
}
@@ -59,16 +67,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.GL.Log %[[ADDONE]]
%5 = math.log1p %arg0 : vector<3xf32>
+ // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
+ // CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
+ // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ %6 = math.log2 %arg0 : vector<3xf32>
+ // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
+ // CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
+ // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ %7 = math.log10 %arg0 : vector<3xf32>
// CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
- %6 = math.roundeven %arg0 : vector<3xf32>
+ %8 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
- %7 = math.rsqrt %arg0 : vector<3xf32>
+ %9 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
- %8 = math.sqrt %arg0 : vector<3xf32>
+ %10 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
- %9 = math.tanh %arg0 : vector<3xf32>
+ %11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
- %10 = math.sin %arg0 : vector<3xf32>
+ %12 = math.sin %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index e9ca838354c0de..393a910c1fb1d7 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -20,26 +20,34 @@ func.func @float32_unary_scalar(%arg0: f32) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%5 = math.log1p %arg0 : f32
+ // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
+ // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
+ // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ %6 = math.log2 %arg0 : f32
+ // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
+ // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
+ // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ %7 = math.log10 %arg0 : f32
// CHECK: spirv.CL.rint %{{.*}}: f32
- %6 = math.roundeven %arg0 : f32
+ %8 = math.roundeven %arg0 : f32
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
- %7 = math.rsqrt %arg0 : f32
+ %9 = math.rsqrt %arg0 : f32
// CHECK: spirv.CL.sqrt %{{.*}}: f32
- %8 = math.sqrt %arg0 : f32
+ %10 = math.sqrt %arg0 : f32
// CHECK: spirv.CL.tanh %{{.*}}: f32
- %9 = math.tanh %arg0 : f32
+ %11 = math.tanh %arg0 : f32
// CHECK: spirv.CL.sin %{{.*}}: f32
- %10 = math.sin %arg0 : f32
+ %12 = math.sin %arg0 : f32
// CHECK: spirv.CL.fabs %{{.*}}: f32
- %11 = math.absf %arg0 : f32
+ %13 = math.absf %arg0 : f32
// CHECK: spirv.CL.ceil %{{.*}}: f32
- %12 = math.ceil %arg0 : f32
+ %14 = math.ceil %arg0 : f32
// CHECK: spirv.CL.floor %{{.*}}: f32
- %13 = math.floor %arg0 : f32
+ %15 = math.floor %arg0 : f32
// CHECK: spirv.CL.erf %{{.*}}: f32
- %14 = math.erf %arg0 : f32
+ %16 = math.erf %arg0 : f32
// CHECK: spirv.CL.round %{{.*}}: f32
- %15 = math.round %arg0 : f32
+ %17 = math.round %arg0 : f32
return
}
@@ -61,16 +69,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%5 = math.log1p %arg0 : vector<3xf32>
+ // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
+ // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
+ // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ %6 = math.log2 %arg0 : vector<3xf32>
+ // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
+ // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
+ // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ %7 = math.log10 %arg0 : vector<3xf32>
// CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
- %6 = math.roundeven %arg0 : vector<3xf32>
+ %8 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
- %7 = math.rsqrt %arg0 : vector<3xf32>
+ %9 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
- %8 = math.sqrt %arg0 : vector<3xf32>
+ %10 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
- %9 = math.tanh %arg0 : vector<3xf32>
+ %11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
- %10 = math.sin %arg0 : vector<3xf32>
+ %12 = math.sin %arg0 : vector<3xf32>
return
}
More information about the Mlir-commits
mailing list