[Mlir-commits] [mlir] 3e746c6 - [mlir] Add support for ExpM1 to GLSL/OpenCL SPIRV Backends
Rob Suderman
llvmlistbot at llvm.org
Mon Jan 24 15:42:17 PST 2022
Author: Rob Suderman
Date: 2022-01-24T15:38:34-08:00
New Revision: 3e746c6d9ef0758c1e06901a99a75b638d6a5655
URL: https://github.com/llvm/llvm-project/commit/3e746c6d9ef0758c1e06901a99a75b638d6a5655
DIFF: https://github.com/llvm/llvm-project/commit/3e746c6d9ef0758c1e06901a99a75b638d6a5655.diff
LOG: [mlir] Add support for ExpM1 to GLSL/OpenCL SPIRV Backends
Add a decomposition of exponential-minus-one (math.expm1 lowered to exp(x) - 1),
mirroring the existing math.log1p lowering, to the GLSL and OpenCL SPIR-V
backends, along with the necessary tests.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D118081
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index ec8402af03009..90588ed9bd5f0 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -30,6 +30,28 @@ using namespace mlir;
// normal RewritePattern.
namespace {
+/// Lowers math.expm1 to exp(x) - 1.0, since SPIR-V has no direct expm1 op.
+/// Note: this loses the extra precision expm1 provides for x near 0. Fails to
+/// match if the result type cannot be converted to a SPIR-V type.
+template <typename ExpOp>
+class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
+public:
+ using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 1);
+ Location loc = operation.getLoc();
+ auto type = this->getTypeConverter()->convertType(operation.getType());
+ if (!type)
+ return failure();
+ auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+ auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
+ return success();
+ }
+};
+
/// Converts math.log1p to SPIR-V ops.
///
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
@@ -44,11 +66,10 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
Location loc = operation.getLoc();
- auto type =
- this->getTypeConverter()->convertType(operation.getOperand().getType());
+ auto type = this->getTypeConverter()->convertType(operation.getType());
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
auto onePlus =
- rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
+ rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
}
@@ -65,7 +86,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// GLSL patterns
patterns
- .add<Log1pOpPattern<spirv::GLSLLogOp>,
+ .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -81,7 +102,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
typeConverter, patterns.getContext());
// OpenCL patterns
- patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
+ patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index f0e0b7e63fdce..c996e7056783a 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -8,26 +8,30 @@ func @float32_unary_scalar(%arg0: f32) {
%0 = math.cos %arg0 : f32
// CHECK: spv.GLSL.Exp %{{.*}}: f32
%1 = math.exp %arg0 : f32
+ // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0
+ // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+ // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+ %2 = math.expm1 %arg0 : f32
// CHECK: spv.GLSL.Log %{{.*}}: f32
- %2 = math.log %arg0 : f32
+ %3 = math.log %arg0 : f32
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.GLSL.Log %[[ADDONE]]
- %3 = math.log1p %arg0 : f32
+ %4 = math.log1p %arg0 : f32
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
- %4 = math.rsqrt %arg0 : f32
+ %5 = math.rsqrt %arg0 : f32
// CHECK: spv.GLSL.Sqrt %{{.*}}: f32
- %5 = math.sqrt %arg0 : f32
+ %6 = math.sqrt %arg0 : f32
// CHECK: spv.GLSL.Tanh %{{.*}}: f32
- %6 = math.tanh %arg0 : f32
+ %7 = math.tanh %arg0 : f32
// CHECK: spv.GLSL.Sin %{{.*}}: f32
- %7 = math.sin %arg0 : f32
+ %8 = math.sin %arg0 : f32
// CHECK: spv.GLSL.FAbs %{{.*}}: f32
- %8 = math.abs %arg0 : f32
+ %9 = math.abs %arg0 : f32
// CHECK: spv.GLSL.Ceil %{{.*}}: f32
- %9 = math.ceil %arg0 : f32
+ %10 = math.ceil %arg0 : f32
// CHECK: spv.GLSL.Floor %{{.*}}: f32
- %10 = math.floor %arg0 : f32
+ %11 = math.floor %arg0 : f32
return
}
@@ -37,20 +41,24 @@ func @float32_unary_vector(%arg0: vector<3xf32>) {
%0 = math.cos %arg0 : vector<3xf32>
// CHECK: spv.GLSL.Exp %{{.*}}: vector<3xf32>
%1 = math.exp %arg0 : vector<3xf32>
+ // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0
+ // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
+ // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+ %2 = math.expm1 %arg0 : vector<3xf32>
// CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32>
- %2 = math.log %arg0 : vector<3xf32>
+ %3 = math.log %arg0 : vector<3xf32>
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.GLSL.Log %[[ADDONE]]
- %3 = math.log1p %arg0 : vector<3xf32>
+ %4 = math.log1p %arg0 : vector<3xf32>
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32>
- %4 = math.rsqrt %arg0 : vector<3xf32>
+ %5 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32>
- %5 = math.sqrt %arg0 : vector<3xf32>
+ %6 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32>
- %6 = math.tanh %arg0 : vector<3xf32>
+ %7 = math.tanh %arg0 : vector<3xf32>
// CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32>
- %7 = math.sin %arg0 : vector<3xf32>
+ %8 = math.sin %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 7580f1f733c49..d0959efc98ab2 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -8,28 +8,32 @@ func @float32_unary_scalar(%arg0: f32) {
%0 = math.cos %arg0 : f32
// CHECK: spv.OCL.exp %{{.*}}: f32
%1 = math.exp %arg0 : f32
+ // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0
+ // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+ // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+ %2 = math.expm1 %arg0 : f32
// CHECK: spv.OCL.log %{{.*}}: f32
- %2 = math.log %arg0 : f32
+ %3 = math.log %arg0 : f32
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.OCL.log %[[ADDONE]]
- %3 = math.log1p %arg0 : f32
+ %4 = math.log1p %arg0 : f32
// CHECK: spv.OCL.rsqrt %{{.*}}: f32
- %4 = math.rsqrt %arg0 : f32
+ %5 = math.rsqrt %arg0 : f32
// CHECK: spv.OCL.sqrt %{{.*}}: f32
- %5 = math.sqrt %arg0 : f32
+ %6 = math.sqrt %arg0 : f32
// CHECK: spv.OCL.tanh %{{.*}}: f32
- %6 = math.tanh %arg0 : f32
+ %7 = math.tanh %arg0 : f32
// CHECK: spv.OCL.sin %{{.*}}: f32
- %7 = math.sin %arg0 : f32
+ %8 = math.sin %arg0 : f32
// CHECK: spv.OCL.fabs %{{.*}}: f32
- %8 = math.abs %arg0 : f32
+ %9 = math.abs %arg0 : f32
// CHECK: spv.OCL.ceil %{{.*}}: f32
- %9 = math.ceil %arg0 : f32
+ %10 = math.ceil %arg0 : f32
// CHECK: spv.OCL.floor %{{.*}}: f32
- %10 = math.floor %arg0 : f32
+ %11 = math.floor %arg0 : f32
// CHECK: spv.OCL.erf %{{.*}}: f32
- %11 = math.erf %arg0 : f32
+ %12 = math.erf %arg0 : f32
return
}
@@ -39,20 +43,24 @@ func @float32_unary_vector(%arg0: vector<3xf32>) {
%0 = math.cos %arg0 : vector<3xf32>
// CHECK: spv.OCL.exp %{{.*}}: vector<3xf32>
%1 = math.exp %arg0 : vector<3xf32>
+ // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0
+ // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
+ // CHECK: spv.FSub %[[EXP]], %[[ONE]]
+ %2 = math.expm1 %arg0 : vector<3xf32>
// CHECK: spv.OCL.log %{{.*}}: vector<3xf32>
- %2 = math.log %arg0 : vector<3xf32>
+ %3 = math.log %arg0 : vector<3xf32>
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.OCL.log %[[ADDONE]]
- %3 = math.log1p %arg0 : vector<3xf32>
+ %4 = math.log1p %arg0 : vector<3xf32>
// CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32>
- %4 = math.rsqrt %arg0 : vector<3xf32>
+ %5 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32>
- %5 = math.sqrt %arg0 : vector<3xf32>
+ %6 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32>
- %6 = math.tanh %arg0 : vector<3xf32>
+ %7 = math.tanh %arg0 : vector<3xf32>
// CHECK: spv.OCL.sin %{{.*}}: vector<3xf32>
- %7 = math.sin %arg0 : vector<3xf32>
+ %8 = math.sin %arg0 : vector<3xf32>
return
}
More information about the Mlir-commits
mailing list