[Mlir-commits] [mlir] 8dae0b6 - [mlir][spirv] arith::RemSIOp OpenCL lowering
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 25 01:44:33 PST 2021
Author: Butygin
Date: 2021-11-25T12:44:06+03:00
New Revision: 8dae0b6b6c9a0256908820b9ef4bc04f0119666b
URL: https://github.com/llvm/llvm-project/commit/8dae0b6b6c9a0256908820b9ef4bc04f0119666b
DIFF: https://github.com/llvm/llvm-project/commit/8dae0b6b6c9a0256908820b9ef4bc04f0119666b.diff
LOG: [mlir][spirv] arith::RemSIOp OpenCL lowering
Differential Revision: https://reviews.llvm.org/D114524
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 6fd69637df1d8..99e15a2f25a11 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -45,11 +45,20 @@ struct ConstantScalarOpPattern final
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts arith.remsi to SPIR-V ops.
+/// Converts arith.remsi to GLSL SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spv.SRem and spv.SMod.
-struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
+struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
+ using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
LogicalResult
@@ -396,7 +405,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// RemSIOpPattern
+// RemSIOpGLSLPattern
//===----------------------------------------------------------------------===//
/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
@@ -406,6 +415,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
/// if either operand can be negative. Emulate it via spv.UMod.
+template <typename SignedAbsOp>
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Value signOperand, OpBuilder &builder) {
assert(lhs.getType() == rhs.getType());
@@ -414,8 +424,8 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Type type = lhs.getType();
// Calculate the remainder with spv.UMod.
- Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+ Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
+ Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
// Fix the sign.
@@ -429,11 +439,26 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
}
LogicalResult
-RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
- adaptor.getOperands()[1],
- adaptor.getOperands()[0], rewriter);
+RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
+ rewriter.replaceOp(op, result);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOpOCLPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
rewriter.replaceOp(op, result);
return success();
@@ -762,7 +787,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
- RemSIOpPattern,
+ RemSIOpGLSLPattern, RemSIOpOCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 5ab13e780d0d4..291c7fdd77b1c 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -43,14 +43,8 @@ func @scalar_srem(%lhs: i32, %rhs: i32) {
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func @float32_unary_scalar(%arg0: f32) {
- // CHECK: spv.GLSL.FAbs %{{.*}}: f32
- %0 = math.abs %arg0 : f32
- // CHECK: spv.GLSL.Ceil %{{.*}}: f32
- %1 = math.ceil %arg0 : f32
// CHECK: spv.FNegate %{{.*}}: f32
- %5 = arith.negf %arg0 : f32
- // CHECK: spv.GLSL.Floor %{{.*}}: f32
- %10 = math.floor %arg0 : f32
+ %0 = arith.negf %arg0 : f32
return
}
@@ -842,3 +836,39 @@ func @sitofp(%arg0 : i64) -> f64 {
}
} // end module
+
+// -----
+
+// Check OpenCL lowering of arith.remsi
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Int16, Kernel], []>, {}>
+} {
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+ // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %0 = arith.remsi %lhs, %rhs: i32
+ return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+ // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : vector<3xi16>
+ // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : vector<3xi16>
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+ %0 = arith.remsi %arg0, %arg1: vector<3xi16>
+ return
+}
+
+} // end module
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index ad32a88a876ea..8cae1ca7d94ef 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
%6 = math.tanh %arg0 : f32
// CHECK: spv.GLSL.Sin %{{.*}}: f32
%7 = math.sin %arg0 : f32
+ // CHECK: spv.GLSL.FAbs %{{.*}}: f32
+ %8 = math.abs %arg0 : f32
+ // CHECK: spv.GLSL.Ceil %{{.*}}: f32
+ %9 = math.ceil %arg0 : f32
+ // CHECK: spv.GLSL.Floor %{{.*}}: f32
+ %10 = math.floor %arg0 : f32
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 8a1a3acc5f0cd..5bfd4e477c21c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
%6 = math.tanh %arg0 : f32
// CHECK: spv.OCL.sin %{{.*}}: f32
%7 = math.sin %arg0 : f32
+ // CHECK: spv.OCL.fabs %{{.*}}: f32
+ %8 = math.abs %arg0 : f32
+ // CHECK: spv.OCL.ceil %{{.*}}: f32
+ %9 = math.ceil %arg0 : f32
+ // CHECK: spv.OCL.floor %{{.*}}: f32
+ %10 = math.floor %arg0 : f32
return
}
More information about the Mlir-commits mailing list