[Mlir-commits] [mlir] 3aab320 - [MLIR][SPIRVToLLVM] Conversion for inverse sqrt and tanh
George Mitenkov
llvmlistbot at llvm.org
Thu Jul 30 00:51:38 PDT 2020
Author: George Mitenkov
Date: 2020-07-30T10:50:48+03:00
New Revision: 3aab320557e7441bc2ce0b51fd6d82838fd0d484
URL: https://github.com/llvm/llvm-project/commit/3aab320557e7441bc2ce0b51fd6d82838fd0d484
DIFF: https://github.com/llvm/llvm-project/commit/3aab320557e7441bc2ce0b51fd6d82838fd0d484.diff
LOG: [MLIR][SPIRVToLLVM] Conversion for inverse sqrt and tanh
This is the second patch on converting GLSL ops to the LLVM dialect.
It introduces patterns to convert `spv.GLSL.InverseSqrt` and `spv.GLSL.Tanh`.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D84633
Added:
Modified:
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 803b05a032da..58d160d30a49 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -91,6 +91,21 @@ static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
}
+/// Builds an `llvm.mlir.constant` of type `dstType` whose value is `value`.
+/// When `srcType` is a vector type, the constant is a splat of `value` over
+/// that vector (the element type must be a float type); otherwise `srcType`
+/// itself must be a float type and a scalar constant is produced.
+static Value createFPConstant(Location loc, Type srcType, Type dstType,
+                              PatternRewriter &rewriter, double value) {
+  auto vecType = srcType.dyn_cast<VectorType>();
+  if (!vecType) {
+    // Scalar case: a single float attribute of the source type.
+    auto scalarAttr = rewriter.getFloatAttr(srcType.cast<FloatType>(), value);
+    return rewriter.create<LLVM::ConstantOp>(loc, dstType, scalarAttr);
+  }
+  // Vector case: splat an element-typed float attribute across the vector.
+  auto elemType = vecType.getElementType().cast<FloatType>();
+  auto elemAttr = rewriter.getFloatAttr(elemType, value);
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, dstType, SplatElementsAttr::get(vecType, elemAttr));
+}
+
/// Utility function for bitfield ops:
/// - `BitFieldInsert`
/// - `BitFieldSExtract`
@@ -590,6 +605,27 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
}
};
+/// Converts `spv.GLSL.InverseSqrt` to `1.0 / llvm.intr.sqrt(x)`. The `1.0`
+/// constant is built with `createFPConstant`, so it follows `srcType`
+/// (scalar float or a splat for a float vector).
+class InverseSqrtPattern
+    : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    // Use the remapped operand supplied by the conversion driver rather than
+    // `op.operand()`, which still refers to the original pre-conversion value.
+    Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, operands.front());
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
+    return success();
+  }
+};
+
/// Converts `spv.Load` and `spv.Store` to LLVM dialect.
template <typename SPIRVop>
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
@@ -821,6 +857,40 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
}
};
+/// Converts `spv.GLSL.Tanh` using the identity
+///
+///   tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+///
+/// emitted as `fdiv(fsub(e, 1), fadd(e, 1))` with `e = llvm.intr.exp(2 * x)`.
+///
+/// NOTE(review): for large positive x, exp(2x) overflows to +inf and the
+/// quotient becomes inf / inf = NaN rather than 1. Confirm whether the
+/// target's precision requirements tolerate this or a range-reduced form
+/// is needed.
+class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = tanhOp.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = tanhOp.getLoc();
+    // Operate on the remapped operand supplied by the conversion driver, not
+    // on `tanhOp.operand()`, which is the original pre-conversion value.
+    Value x = operands.front();
+    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
+    Value multiplied = rewriter.create<LLVM::FMulOp>(loc, dstType, two, x);
+    Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    Value numerator =
+        rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+    Value denominator =
+        rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
+                                              denominator);
+    return success();
+  }
+};
+
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
public:
using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
@@ -1052,7 +1122,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
- DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, TanPattern,
+ DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
+ InverseSqrtPattern, TanPattern, TanhPattern,
// Logical ops
DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
index 1907619445c9..ab501b8aabfd 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir
@@ -103,3 +103,33 @@ func @tan(%arg0: f32) {
%0 = spv.GLSL.Tan %arg0 : f32
return
}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Tanh
+//===----------------------------------------------------------------------===//
+
+// Verifies the scalar f32 expansion tanh(x) = (exp(2x) - 1) / (exp(2x) + 1):
+// constant 2, fmul, exp, constant 1, fsub, fadd, fdiv, in that order.
+// NOTE(review): only the scalar f32 path is exercised; the vector (splat
+// constant) path is not covered by this test.
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: f32) {
+ // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : !llvm.float
+ // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]]) : (!llvm.float) -> !llvm.float
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : !llvm.float
+ // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float
+ // CHECK: llvm.fdiv %[[T0]], %[[T1]] : !llvm.float
+ %0 = spv.GLSL.Tanh %arg0 : f32
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.InverseSqrt
+//===----------------------------------------------------------------------===//
+
+// Verifies the scalar f32 lowering InverseSqrt(x) = 1.0 / sqrt(x): constant 1,
+// llvm.intr.sqrt, then fdiv with the constant as the numerator.
+// NOTE(review): only the scalar f32 path is exercised; the vector (splat
+// constant) path is not covered by this test.
+// CHECK-LABEL: @inverse_sqrt
+func @inverse_sqrt(%arg0: f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}}) : (!llvm.float) -> !llvm.float
+ // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+ %0 = spv.GLSL.InverseSqrt %arg0 : f32
+ return
+}
More information about the Mlir-commits
mailing list