[Mlir-commits] [mlir] 533ec92 - [mlir][spirv] Add pattern to lower math.copysign
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 1 09:06:53 PDT 2022
Author: Lei Zhang
Date: 2022-04-01T12:06:47-04:00
New Revision: 533ec929f67d0169f247cb73835a127053cc5933
URL: https://github.com/llvm/llvm-project/commit/533ec929f67d0169f247cb73835a127053cc5933
DIFF: https://github.com/llvm/llvm-project/commit/533ec929f67d0169f247cb73835a127053cc5933.diff
LOG: [mlir][spirv] Add pattern to lower math.copysign
This follows the logic:
https://git.musl-libc.org/cgit/musl/tree/src/math/copysignf.c
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D122910
Added:
mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6d309d6760830..956118a5d9f77 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
@@ -30,14 +31,74 @@ using namespace mlir;
// normal RewritePattern.
namespace {
+/// Converts math.copysign to SPIR-V ops.
+///
+/// The lowering works on the raw bit patterns: both operands are bitcast to
+/// integers of the same width as the float element type, the non-sign bits
+/// are taken from the lhs and the sign bit from the rhs, and the combined
+/// bits are bitcast back to the converted result type.
+///
+/// Scalar float and vector-of-float result types are handled (vectors are
+/// asserted to be rank-1); any other result type fails to match.
+class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = getTypeConverter()->convertType(copySignOp.getType());
+    if (!type)
+      return failure();
+
+    // Find the float element type, which determines the mask width.
+    FloatType floatType;
+    if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
+      floatType = scalarType;
+    } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
+      floatType = vectorType.getElementType().cast<FloatType>();
+    } else {
+      return failure();
+    }
+
+    Location loc = copySignOp.getLoc();
+    int bitwidth = floatType.getWidth();
+    Type intType = rewriter.getIntegerType(bitwidth);
+
+    // Build the masks in 64-bit arithmetic. Shifting a 32-bit `1u` by
+    // (bitwidth - 1) is undefined behaviour once bitwidth exceeds 32, which
+    // would produce a wrong sign mask for f64.
+    const uint64_t signBit = uint64_t{1} << (bitwidth - 1);
+
+    // signMask has only the top bit set; valueMask has every bit below it set.
+    Value signMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType,
+        rewriter.getIntegerAttr(intType, static_cast<int64_t>(signBit)));
+    Value valueMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType,
+        rewriter.getIntegerAttr(intType, static_cast<int64_t>(signBit - 1u)));
+
+    // For vectors, splat both scalar masks into vectors of matching length and
+    // switch intType to the integer vector type used by the ops below.
+    if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
+      assert(vectorType.getRank() == 1);
+      int count = vectorType.getNumElements();
+      intType = VectorType::get(count, intType);
+
+      SmallVector<Value> signSplat(count, signMask);
+      signMask =
+          rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+
+      SmallVector<Value> valueSplat(count, valueMask);
+      valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
+                                                               valueSplat);
+    }
+
+    // Reinterpret the (already converted) operands as integers.
+    Value lhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+    Value rhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+
+    // Magnitude bits come from the lhs, the sign bit from the rhs.
+    Value value = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{lhsCast, valueMask});
+    Value sign = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{rhsCast, signMask});
+
+    // Merge the two parts and reinterpret the bits as the result float type.
+    Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
+                                                       ValueRange{value, sign});
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
+    return success();
+  }
+};
+
/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
/// these operations.
template <typename ExpOp>
-class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
-public:
- using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
+struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
@@ -57,9 +118,8 @@ class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
/// these operations.
template <typename LogOp>
-class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
-public:
- using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
+struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
@@ -83,6 +143,8 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
namespace mlir {
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
+ // Core patterns
+ patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
// GLSL patterns
patterns
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
new file mode 100644
index 0000000000000..f2379f2cab5cd
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// Scalar f32 input: the CHECK lines below verify that math.copysign is
+// rewritten into i32 mask constants, bitcasts, and bitwise and/or ops.
+func @copy_sign_scalar(%value: f32, %sign: f32) -> f32 {
+ %0 = math.copysign %value, %sign : f32
+ return %0: f32
+}
+
+// CHECK-LABEL: func @copy_sign_scalar
+// CHECK-SAME: (%[[VALUE:.+]]: f32, %[[SIGN:.+]]: f32)
+// CHECK: %[[SMASK:.+]] = spv.Constant -2147483648 : i32
+// CHECK: %[[VMASK:.+]] = spv.Constant 2147483647 : i32
+// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : f32 to i32
+// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : f32 to i32
+// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i32
+// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i32
+// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : i32
+// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16, Int16], []>, {}> } {
+
+// Vector f16 input: the CHECK lines below verify that the i16 masks are
+// splatted with spv.CompositeConstruct and the bit ops run on vector<3xi16>.
+func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vector<3xf16> {
+ %0 = math.copysign %value, %sign : vector<3xf16>
+ return %0: vector<3xf16>
+}
+
+}
+
+// CHECK-LABEL: func @copy_sign_vector
+// CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
+// CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
+// CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
+// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
+// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
+// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
+// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
+// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
+// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SVMASK]] : vector<3xi16>
+// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
+// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
+// CHECK: return %[[RESULT]]
More information about the Mlir-commits
mailing list