[Mlir-commits] [mlir] f3bc0fc - [mlir][spirv] Define spv.ISubBorrowOp
Lei Zhang
llvmlistbot at llvm.org
Wed Jun 15 17:40:20 PDT 2022
Author: Lei Zhang
Date: 2022-06-15T20:38:53-04:00
New Revision: f3bc0fccd68a1208d568360b0e3f6483759bae4a
URL: https://github.com/llvm/llvm-project/commit/f3bc0fccd68a1208d568360b0e3f6483759bae4a
DIFF: https://github.com/llvm/llvm-project/commit/f3bc0fccd68a1208d568360b0e3f6483759bae4a.diff
LOG: [mlir][spirv] Define spv.ISubBorrowOp
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D127909
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 468934190d606..30fd95f6a1fe5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -390,6 +390,57 @@ def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub",
// -----
+def SPV_ISubBorrowOp : SPV_BinaryOp<"ISubBorrow", SPV_AnyStruct, SPV_Integer,
+ [NoSideEffect]> {
+ let summary = [{
+ Result is the unsigned integer subtraction of Operand 2 from Operand 1,
+ and what it needed to borrow.
+ }];
+
+ let description = [{
+ Result Type must be from OpTypeStruct. The struct must have two
+ members, and the two members must be the same type. The member type
+ must be a scalar or vector of integer type, whose Signedness operand is
+ 0.
+
+ Operand 1 and Operand 2 must have the same type as the members of Result
+ Type. These are consumed as unsigned integers.
+
+ Results are computed per component.
+
+ Member 0 of the result gets the low-order bits (full component width) of
+ the subtraction. That is, if Operand 1 is larger than Operand 2, member
+ 0 gets the full value of the subtraction; if Operand 2 is larger than
+ Operand 1, member 0 gets 2^w + Operand 1 - Operand 2, where w is the
+ component width.
+
+ Member 1 of the result gets 0 if Operand 1 ≥ Operand 2, and gets 1
+ otherwise.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %2 = spv.ISubBorrow %0, %1 : !spv.struct<(i32, i32)>
+ %5 = spv.ISubBorrow %3, %4 : !spv.struct<(vector<2xi32>, vector<2xi32>)>
+ ```
+ }];
+
+ // Both operands are integer scalars or vectors. The custom verifier
+ // (hasVerifier below) additionally requires the result struct to have
+ // exactly two members and requires both operand types and both member
+ // types to be identical.
+ let arguments = (ins
+ SPV_ScalarOrVectorOf<SPV_Integer>:$operand1,
+ SPV_ScalarOrVectorOf<SPV_Integer>:$operand2
+ );
+
+ let results = (outs
+ SPV_AnyStruct:$result
+ );
+
+ let hasVerifier = 1;
+}
+
+// -----
+
def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv",
SPV_Integer,
[UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0b29e47d65d20..79d1222a04c71 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4090,6 +4090,7 @@ def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
def SPV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
def SPV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>;
@@ -4214,32 +4215,32 @@ def SPV_OpcodeAttr :
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
- SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan,
- SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual,
- SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
- SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
- SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
- SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
- SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
- SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
- SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
- SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
- SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
- SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
- SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
- SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
- SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
- SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
- SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange,
- SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak,
- SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
- SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
- SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
- SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
- SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
- SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
- SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine,
- SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+ SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpISubBorrow,
+ SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
+ SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+ SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+ SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+ SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+ SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+ SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+ SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+ SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+ SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
+ SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
+ SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
+ SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
+ SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
+ SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
+ SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+ SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
+ SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
+ SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
+ SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
+ SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
+ SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
+ SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
+ SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
+ SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5d7d943eed93f..123f7076835b6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -28,6 +28,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"
@@ -2839,6 +2840,58 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
+//===----------------------------------------------------------------------===//
+// spv.ISubBorrowOp
+//===----------------------------------------------------------------------===//
+
+// Verifies the ISubBorrow invariants that the ODS type constraints do not
+// express on their own: the result struct has exactly two members, and the
+// two operands and the two struct members all have one and the same type.
+LogicalResult spirv::ISubBorrowOp::verify() {
+ // The result is constrained to SPV_AnyStruct in ODS, so this cast cannot
+ // fail here.
+ auto resultType = getType().cast<spirv::StructType>();
+ if (resultType.getNumElements() != 2)
+ return emitOpError("expected result struct type containing two members");
+
+ // Collect all four types and require them to be identical. The operands
+ // are already constrained to integer scalars/vectors in ODS, so equality
+ // also forces both struct members to be that same integer type.
+ SmallVector<Type, 4> types;
+ types.push_back(operand1().getType());
+ types.push_back(operand2().getType());
+ types.push_back(resultType.getElementType(0));
+ types.push_back(resultType.getElementType(1));
+ if (!llvm::is_splat(types))
+ return emitOpError(
+ "expected all operand types and struct member types are the same");
+
+ return success();
+}
+
+// Parses the custom assembly form:
+//   spv.ISubBorrow {attr-dict}? %operand1, %operand2 : !spv.struct<(T, T)>
+// Operand types are not written in the textual form; both operands are
+// resolved with the type of member 0 of the result struct.
+ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
+ OperationState &state) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+ if (parser.parseOptionalAttrDict(state.attributes) ||
+ parser.parseOperandList(operands) || parser.parseColon())
+ return failure();
+
+ Type resultType;
+ // Remember where the trailing type starts so that the struct-shape error
+ // and any operand-resolution error are reported at the type.
+ auto loc = parser.getCurrentLocation();
+ if (parser.parseType(resultType))
+ return failure();
+
+ // The operand types are derived from the struct below, so anything other
+ // than a two-member struct is rejected here, before the op is built.
+ auto structType = resultType.dyn_cast<spirv::StructType>();
+ if (!structType || structType.getNumElements() != 2)
+ return parser.emitError(loc, "expected spv.struct type with two members");
+
+ // Both operands take member 0's type. Agreement with member 1 is not
+ // checked here; verify() enforces it.
+ SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
+ if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
+ return failure();
+
+ state.addTypes(resultType);
+ return success();
+}
+
+// Prints the custom form accepted by parse(): a space, the optional
+// attribute dictionary, the operands, then " : " and the result struct type.
+// NOTE(review): no separator is printed between the attribute dictionary and
+// the first operand. If printOptionalAttrDict emits no trailing space, a
+// non-empty dictionary comes out as `{...}%0`. This is still parseable, but
+// confirm the formatting is intended.
+void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
+ printer << ' ';
+ printer.printOptionalAttrDict((*this)->getAttrs());
+ printer.printOperands((*this)->getOperands());
+ printer << " : " << getType();
+}
+
//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 214f755266a4c..22159b0dde15c 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -150,6 +150,58 @@ func.func @isub_scalar(%arg: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spv.ISubBorrow
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @isub_borrow_scalar
+func.func @isub_borrow_scalar(%arg: i32) -> !spv.struct<(i32, i32)> {
+ // Scalar case: both operands are i32, and the result is a struct of two
+ // i32 members written in the custom assembly form.
+ // CHECK: spv.ISubBorrow %{{.+}}, %{{.+}} : !spv.struct<(i32, i32)>
+ %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(i32, i32)>
+ return %0 : !spv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @isub_borrow_vector
+func.func @isub_borrow_vector(%arg: vector<3xi32>) -> !spv.struct<(vector<3xi32>, vector<3xi32>)> {
+ // Vector case: operands are vector<3xi32>, and both struct members use
+ // that same vector type.
+ // CHECK: spv.ISubBorrow %{{.+}}, %{{.+}} : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+ %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32, i32, i32)> {
+ // Custom form with a three-member struct: the custom parser rejects it
+ // because it needs exactly two members to derive the operand types.
+ // expected-error @+1 {{expected spv.struct type with two members}}
+ %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(i32, i32, i32)>
+ return %0 : !spv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32)> {
+ // Generic form bypasses the custom parser, so the verifier's member-count
+ // check reports the one-member struct.
+ // expected-error @+1 {{expected result struct type containing two members}}
+ %0 = "spv.ISubBorrow"(%arg, %arg): (i32, i32) -> !spv.struct<(i32)>
+ return %0 : !spv.struct<(i32)>
+}
+
+// -----
+
+func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32, i64)> {
+ // The two struct members differ (i32 vs i64), which fails the verifier's
+ // same-type check.
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spv.ISubBorrow"(%arg, %arg): (i32, i32) -> !spv.struct<(i32, i64)>
+ return %0 : !spv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @isub_borrow(%arg: i64) -> !spv.struct<(i32, i32)> {
+ // The operands (i64) do not match the struct members (i32), which also
+ // fails the verifier's same-type check.
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spv.ISubBorrow"(%arg, %arg): (i64, i64) -> !spv.struct<(i32, i32)>
+ return %0 : !spv.struct<(i32, i32)>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.SDiv
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list