[Mlir-commits] [mlir] 4ba45a7 - [mlir][StandardToSPIRV] Fix conversion for signed remainder
Lei Zhang
llvmlistbot at llvm.org
Mon Jul 13 13:18:53 PDT 2020
Author: Lei Zhang
Date: 2020-07-13T16:15:31-04:00
New Revision: 4ba45a778a13eab1495a75a14682f874016f3d21
URL: https://github.com/llvm/llvm-project/commit/4ba45a778a13eab1495a75a14682f874016f3d21
DIFF: https://github.com/llvm/llvm-project/commit/4ba45a778a13eab1495a75a14682f874016f3d21.diff
LOG: [mlir][StandardToSPIRV] Fix conversion for signed remainder
Per Vulkan's SPIR-V environment spec, "for the OpSRem and OpSMod
instructions, if either operand is negative the result is undefined."
So we cannot directly use spv.SRem/spv.SMod if either operand can be
negative. Emulate it via spv.UMod on the operands' absolute values,
then give the result the sign of the dividend.
Because the emulation uses spv.SNegate, this commit also defines
spv.SNegate.
Differential Revision: https://reviews.llvm.org/D83679
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index 350e3659a28d..5a12e6f36ec4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -452,6 +452,31 @@ def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod", SPV_Integer, []> {
// -----
+def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate", SPV_Integer, []> {
+ let summary = "Signed-integer subtract of Operand from zero.";
+
+ let description = [{
+ Result Type must be a scalar or vector of integer type.
+
+ Operand’s type must be a scalar or vector of integer type. It must
+ have the same number of components as Result Type. The component width
+ must equal the component width in Result Type.
+
+ Results are computed per component.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %1 = spv.SNegate %0 : i32
+ %3 = spv.SNegate %2 : vector<4xi32>
+ ```
+ }];
+}
+
+// -----
+
def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
let summary = [{
Signed remainder operation for the remainder whose sign matches the sign
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index f114c878569d..cbff82efdfd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3150,6 +3150,7 @@ def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>;
def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>;
def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>;
def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>;
+def SPV_OC_OpSNegate : I32EnumAttrCase<"OpSNegate", 126>;
def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
@@ -3271,41 +3272,42 @@ def SPV_OpcodeAttr :
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
- SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
- SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
- SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
- SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
- SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
- SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
- SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
- SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
- SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
- SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
- SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
- SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
- SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
- SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
- SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
- SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
- SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
- SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
- SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
- SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
- SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
- SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
- SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
- SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
- SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
- SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
- SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
- SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
- SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
- SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
- SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
- SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
- SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
- SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
- SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
+ SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
+ SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
+ SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
+ SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
+ SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+ SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+ SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+ SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+ SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+ SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+ SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+ SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+ SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
+ SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+ SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
+ SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
+ SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
+ SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
+ SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+ SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
+ SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
+ SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
+ SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
+ SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
+ SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
+ SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
+ SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+ SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
+ SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
+ SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
+ SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
+ SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
+ SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
+ SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
+ SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
+ SPV_OC_OpCooperativeMatrixLengthNV
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index e1b477126a02..9789122809ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -26,6 +26,12 @@ class SPV_LogicalBinaryOp<string mnemonic, Type operandsType,
SameOperandsAndResultShape])> {
let parser = [{ return ::parseLogicalBinaryOp(parser, result); }];
let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+
+ let builders = [
+ OpBuilder<
+ "OpBuilder &builder, OperationState &state, Value lhs, Value rhs",
+ "::buildLogicalBinaryOp(builder, state, lhs, rhs);">
+ ];
}
class SPV_LogicalUnaryOp<string mnemonic, Type operandType,
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 6bb7a17ae46f..dad8bfc0173f 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,35 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
+/// the sign of `signOperand`.
+///
+/// Note that this is needed for Vulkan. Per Vulkan's SPIR-V environment
+/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
+/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
+/// if either operand can be negative. Emulate it via spv.UMod.
+static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
+ Value signOperand, OpBuilder &builder) {
+ assert(lhs.getType() == rhs.getType());
+ assert(lhs == signOperand || rhs == signOperand);
+
+ Type type = lhs.getType();
+
+ // Calculate the remainder with spv.UMod.
+ Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
+ Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+ Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+
+ // Fix the sign.
+ Value isPositive;
+ if (lhs == signOperand)
+ isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+ else
+ isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
+ Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
+ return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+}
+
/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
/// an index into a 1-D array with each element having `sourceBits`. When
/// accessing an element in the array treating as having elements of
@@ -308,6 +337,19 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
}
};
+/// Converts std.remi_signed to SPIR-V ops.
+///
+/// This cannot be merged into the template unary/binary pattern due to
+/// Vulkan restrictions over spv.SRem and spv.SMod.
+class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
+public:
+ using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;
+
+ LogicalResult
+ matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts bitwise standard operations to SPIR-V operations. This is a special
/// pattern other than the BinaryOpPatternPattern because if the operands are
/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
@@ -506,6 +548,20 @@ class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
} // namespace
+//===----------------------------------------------------------------------===//
+// SignedRemIOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult SignedRemIOpPattern::matchAndRewrite(
+ SignedRemIOp remOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
+ operands[1], operands[0], rewriter);
+ rewriter.replaceOp(remOp, result);
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp with composite type.
//===----------------------------------------------------------------------===//
@@ -1005,6 +1061,9 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<
+ // Unary and binary patterns
+ BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
+ BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
@@ -1020,7 +1079,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
- UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
UnaryAndBinaryOpPattern<SignedShiftRightOp,
spirv::ShiftRightArithmeticOp>,
UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
@@ -1031,19 +1089,28 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
- AllocOpPattern, DeallocOpPattern,
- BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
- BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
- BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
- CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
- ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
+ SignedRemIOpPattern, XOrOpPattern,
+
+ // Comparison patterns
+ BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
+
+ // Constant patterns
+ ConstantCompositeOpPattern, ConstantScalarOpPattern,
+
+ // Memory patterns
+ AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+ LoadOpPattern, StoreOpPattern,
+
+ ReturnOpPattern, SelectOpPattern,
+
+ // Type cast patterns
ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
- TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
- context, typeConverter);
+ TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
+ typeConverter);
}
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6c8319224974..9d0570257d42 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -844,6 +844,18 @@ static LogicalResult verifyShiftOp(Operation *op) {
return success();
}
+static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
+ Value lhs, Value rhs) {
+ assert(lhs.getType() == rhs.getType());
+
+ Type boolType = builder.getI1Type();
+ if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+ boolType = VectorType::get(vecType.getShape(), boolType);
+ state.addTypes(boolType);
+
+ state.addOperands({lhs, rhs});
+}
+
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index c232395d80db..a93bf792b34f 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -22,12 +22,23 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
%2 = muli %lhs, %rhs: i32
// CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32
%3 = divi_signed %lhs, %rhs: i32
- // CHECK: spv.SRem %{{.*}}, %{{.*}}: i32
- %4 = remi_signed %lhs, %rhs: i32
// CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32
- %5 = divi_unsigned %lhs, %rhs: i32
+ %4 = divi_unsigned %lhs, %rhs: i32
// CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
- %6 = remi_unsigned %lhs, %rhs: i32
+ %5 = remi_unsigned %lhs, %rhs: i32
+ return
+}
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+ // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %0 = remi_signed %lhs, %rhs: i32
return
}
@@ -75,13 +86,24 @@ func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
// Check int vector types.
// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
// CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8>
%0 = divi_signed %arg0, %arg0: vector<2xi8>
- // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi16>
- %1 = remi_signed %arg1, %arg1: vector<3xi16>
// CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64>
- %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
+ %1 = divi_unsigned %arg1, %arg1: vector<4xi64>
+ return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+ // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : vector<3xi16>
+ // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : vector<3xi16>
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+ %0 = remi_signed %arg0, %arg1: vector<3xi16>
return
}
@@ -132,8 +154,8 @@ module attributes {
func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
// CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
%0 = divi_signed %arg0, %arg0: vector<2xi8>
- // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
- %1 = remi_signed %arg1, %arg1: vector<3xi16>
+ // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<3xi32>
+ %1 = divi_signed %arg1, %arg1: vector<3xi16>
return
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
index 55c67dafe6bb..9752c0d0e579 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir
@@ -71,6 +71,11 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.SMod %arg0, %arg1 : vector<4xi32>
spv.Return
}
+ spv.func @snegate(%arg0 : vector<4xi32>) "None" {
+ // CHECK: {{%.*}} = spv.SNegate {{%.*}} : vector<4xi32>
+ %0 = spv.SNegate %arg0 : vector<4xi32>
+ spv.Return
+ }
spv.func @srem(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" {
// CHECK: {{%.*}} = spv.SRem {{%.*}}, {{%.*}} : vector<4xi32>
%0 = spv.SRem %arg0, %arg1 : vector<4xi32>
diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
index 85998fb03efd..de574b1510c9 100644
--- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir
@@ -174,6 +174,17 @@ func @smod_scalar(%arg: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spv.SNegate
+//===----------------------------------------------------------------------===//
+
+func @snegate_scalar(%arg: i32) -> i32 {
+ // CHECK: spv.SNegate
+ %0 = spv.SNegate %arg : i32
+ return %0 : i32
+}
+
+// -----
//===----------------------------------------------------------------------===//
// spv.SRem
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list