[Mlir-commits] [mlir] 3027582 - [mlir][spirv] Define `OpSMulExtended` and `OpUMulExtended` ops
Jakub Kuderski
llvmlistbot at llvm.org
Wed Nov 23 16:06:12 PST 2022
Author: Jakub Kuderski
Date: 2022-11-23T19:04:56-05:00
New Revision: 30275821f1cacd855c7f9246351242223bed9d29
URL: https://github.com/llvm/llvm-project/commit/30275821f1cacd855c7f9246351242223bed9d29
DIFF: https://github.com/llvm/llvm-project/commit/30275821f1cacd855c7f9246351242223bed9d29.diff
LOG: [mlir][spirv] Define `OpSMulExtended` and `OpUMulExtended` ops
These perform exact multiplication and return the high half as a second
result.
Also factor out common code shared between 'extended binary ops'.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D138624
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index b124625f70543..1d6c98d017350 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -53,7 +53,32 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
- }
+}
+
+class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
+ list<Trait> traits = []> :
+ // Result type is a struct with two operand-typed elements.
+ SPIRV_BinaryOp<mnemonic, SPIRV_AnyStruct, SPIRV_Integer, traits> {
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
+ SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
+ );
+
+ let results = (outs
+ SPIRV_AnyStruct:$result
+ );
+
+ let builders = [
+ OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
+ build($_builder, $_state,
+ ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
+ operand1, operand2);
+ }]>
+ ];
+
+ // These ops require a custom verifier.
+ let hasVerifier = 1;
+}
// -----
@@ -321,9 +346,8 @@ def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
// -----
-def SPIRV_IAddCarryOp : SPIRV_BinaryOp<"IAddCarry",
- SPIRV_AnyStruct, SPIRV_Integer,
- [Commutative, Pure]> {
+def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
+ [Commutative, Pure]> {
let summary = [{
Integer addition of Operand 1 and Operand 2, including the carry.
}];
@@ -355,25 +379,6 @@ def SPIRV_IAddCarryOp : SPIRV_BinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
-
- let arguments = (ins
- SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
- SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
- );
-
- let results = (outs
- SPIRV_AnyStruct:$result
- );
-
- let builders = [
- OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
- build($_builder, $_state,
- ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
- operand1, operand2);
- }]>
- ];
-
- let hasVerifier = 1;
}
// -----
@@ -418,6 +423,75 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
// -----
+def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
+ [Pure, Commutative]> {
+ let summary = [{
+ Result is the full value of the signed integer multiplication of Operand
+ 1 and Operand 2.
+ }];
+
+ let description = [{
+ Result Type must be from OpTypeStruct. The struct must have two
+ members, and the two members must be the same type. The member type
+ must be a scalar or vector of integer type.
+
+ Operand 1 and Operand 2 must have the same type as the members of Result
+ Type. These are consumed as signed integers.
+
+ Results are computed per component.
+
+ Member 0 of the result gets the low-order bits of the multiplication.
+
+ Member 1 of the result gets the high-order bits of the multiplication.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+ %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
+ [Pure, Commutative]> {
+ let summary = [{
+ Result is the full value of the unsigned integer multiplication of
+ Operand 1 and Operand 2.
+ }];
+
+ let description = [{
+ Result Type must be from OpTypeStruct. The struct must have two
+ members, and the two members must be the same type. The member type
+ must be a scalar or vector of integer type, whose Signedness operand is
+ 0.
+
+ Operand 1 and Operand 2 must have the same type as the members of Result
+ Type. These are consumed as unsigned integers.
+
+ Results are computed per component.
+
+ Member 0 of the result gets the low-order bits of the multiplication.
+
+ Member 1 of the result gets the high-order bits of the multiplication.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+ %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
@@ -458,8 +532,8 @@ def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
// -----
-def SPIRV_ISubBorrowOp : SPIRV_BinaryOp<"ISubBorrow", SPIRV_AnyStruct, SPIRV_Integer,
- [Pure]> {
+def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
+ [Pure]> {
let summary = [{
Result is the unsigned integer subtraction of Operand 2 from Operand 1,
and what it needed to borrow.
@@ -494,25 +568,6 @@ def SPIRV_ISubBorrowOp : SPIRV_BinaryOp<"ISubBorrow", SPIRV_AnyStruct, SPIRV_Int
%2 = spirv.ISubBorrow %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
-
- let arguments = (ins
- SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
- SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
- );
-
- let results = (outs
- SPIRV_AnyStruct:$result
- );
-
- let builders = [
- OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
- build($_builder, $_state,
- ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
- operand1, operand2);
- }]>
- ];
-
- let hasVerifier = 1;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 77fe5ac9e84e9..400947a4043e3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4243,6 +4243,8 @@ def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
+def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>;
+def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>;
def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>;
@@ -4372,17 +4374,17 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, SPIRV_OC_OpImageDrefGather,
SPIRV_OC_OpImage, SPIRV_OC_OpImageQuerySize, SPIRV_OC_OpConvertFToU,
SPIRV_OC_OpConvertFToS, SPIRV_OC_OpConvertSToF, SPIRV_OC_OpConvertUToF,
- SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric,
+ SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric,
SPIRV_OC_OpGenericCastToPtr, SPIRV_OC_OpGenericCastToPtrExplicit, SPIRV_OC_OpBitcast,
SPIRV_OC_OpSNegate, SPIRV_OC_OpFNegate, SPIRV_OC_OpIAdd, SPIRV_OC_OpFAdd,
SPIRV_OC_OpISub, SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod,
SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, SPIRV_OC_OpVectorTimesScalar,
SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry,
- SPIRV_OC_OpISubBorrow, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered,
- SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual,
- SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect,
- SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,
+ SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
+ SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual,
+ SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot,
+ SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,
SPIRV_OC_OpSGreaterThan, SPIRV_OC_OpUGreaterThanEqual, SPIRV_OC_OpSGreaterThanEqual,
SPIRV_OC_OpULessThan, SPIRV_OC_OpSLessThan, SPIRV_OC_OpULessThanEqual,
SPIRV_OC_OpSLessThanEqual, SPIRV_OC_OpFOrdEqual, SPIRV_OC_OpFUnordEqual,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c16d3d0b2d8e4..606562292ebcf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
@@ -31,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"
+#include <cassert>
#include <numeric>
using namespace mlir;
@@ -763,6 +765,53 @@ static inline bool isMergeBlock(Block &block) {
isa<spirv::MergeOp>(block.front());
}
+template <typename ExtendedBinaryOp>
+static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
+ auto resultType = op.getType().template cast<spirv::StructType>();
+ if (resultType.getNumElements() != 2)
+ return op.emitOpError("expected result struct type containing two members");
+
+ if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
+ resultType.getElementType(0),
+ resultType.getElementType(1)}))
+ return op.emitOpError(
+ "expected all operand types and struct member types are the same");
+
+ return success();
+}
+
+static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseOperandList(operands) || parser.parseColon())
+ return failure();
+
+ Type resultType;
+ SMLoc loc = parser.getCurrentLocation();
+ if (parser.parseType(resultType))
+ return failure();
+
+ auto structType = resultType.dyn_cast<spirv::StructType>();
+ if (!structType || structType.getNumElements() != 2)
+ return parser.emitError(loc, "expected spirv.struct type with two members");
+
+ SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
+ if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
+ return failure();
+
+ result.addTypes(resultType);
+ return success();
+}
+
+static void printArithmeticExtendedBinaryOp(Operation *op,
+ OpAsmPrinter &printer) {
+ printer << ' ';
+ printer.printOptionalAttrDict(op->getAttrs());
+ printer.printOperands(op->getOperands());
+ printer << " : " << op->getResultTypes().front();
+}
+
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
@@ -2990,48 +3039,16 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult spirv::IAddCarryOp::verify() {
- auto resultType = getType().cast<spirv::StructType>();
- if (resultType.getNumElements() != 2)
- return emitOpError("expected result struct type containing two members");
-
- if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
- resultType.getElementType(0),
- resultType.getElementType(1)}))
- return emitOpError(
- "expected all operand types and struct member types are the same");
-
- return success();
+ return ::verifyArithmeticExtendedBinaryOp(*this);
}
ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseOperandList(operands) || parser.parseColon())
- return failure();
-
- Type resultType;
- SMLoc loc = parser.getCurrentLocation();
- if (parser.parseType(resultType))
- return failure();
-
- auto structType = resultType.dyn_cast<spirv::StructType>();
- if (!structType || structType.getNumElements() != 2)
- return parser.emitError(loc, "expected spirv.struct type with two members");
-
- SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
- if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
- return failure();
-
- result.addTypes(resultType);
- return success();
+ return ::parseArithmeticExtendedBinaryOp(parser, result);
}
void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
- printer << ' ';
- printer.printOptionalAttrDict((*this)->getAttrs());
- printer.printOperands((*this)->getOperands());
- printer << " : " << getType();
+ ::printArithmeticExtendedBinaryOp(*this, printer);
}
//===----------------------------------------------------------------------===//
@@ -3039,48 +3056,50 @@ void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ISubBorrowOp::verify() {
- auto resultType = getType().cast<spirv::StructType>();
- if (resultType.getNumElements() != 2)
- return emitOpError("expected result struct type containing two members");
-
- if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
- resultType.getElementType(0),
- resultType.getElementType(1)}))
- return emitOpError(
- "expected all operand types and struct member types are the same");
-
- return success();
+ return ::verifyArithmeticExtendedBinaryOp(*this);
}
ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseOperandList(operands) || parser.parseColon())
- return failure();
+ return ::parseArithmeticExtendedBinaryOp(parser, result);
+}
- Type resultType;
- auto loc = parser.getCurrentLocation();
- if (parser.parseType(resultType))
- return failure();
+void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
+ ::printArithmeticExtendedBinaryOp(*this, printer);
+}
- auto structType = resultType.dyn_cast<spirv::StructType>();
- if (!structType || structType.getNumElements() != 2)
- return parser.emitError(loc, "expected spirv.struct type with two members");
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
- SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
- if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
- return failure();
+LogicalResult spirv::SMulExtendedOp::verify() {
+ return ::verifyArithmeticExtendedBinaryOp(*this);
+}
- result.addTypes(resultType);
- return success();
+ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return ::parseArithmeticExtendedBinaryOp(parser, result);
}
-void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
- printer << ' ';
- printer.printOptionalAttrDict((*this)->getAttrs());
- printer.printOperands((*this)->getOperands());
- printer << " : " << getType();
+void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
+ ::printArithmeticExtendedBinaryOp(*this, printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::UMulExtendedOp::verify() {
+ return ::verifyArithmeticExtendedBinaryOp(*this);
+}
+
+ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return ::parseArithmeticExtendedBinaryOp(parser, result);
+}
+
+void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
+ ::printArithmeticExtendedBinaryOp(*this, printer);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index c59179eda9fd7..9617204d3419c 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -254,6 +254,110 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smul_extended_scalar
+func.func @smul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
+ %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @smul_extended_vector
+func.func @smul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> {
+ // expected-error @+1 {{expected spirv.struct type with two members}}
+ %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)>
+ return %0 : !spirv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32)> {
+ // expected-error @+1 {{expected result struct type containing two members}}
+ %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)>
+ return %0 : !spirv.struct<(i32)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> {
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)>
+ return %0 : !spirv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> {
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spirv.SMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umul_extended_scalar
+func.func @umul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
+ %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umul_extended_vector
+func.func @umul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> {
+ // expected-error @+1 {{expected spirv.struct type with two members}}
+ %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)>
+ return %0 : !spirv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32)> {
+ // expected-error @+1 {{expected result struct type containing two members}}
+ %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)>
+ return %0 : !spirv.struct<(i32)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> {
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)>
+ return %0 : !spirv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> {
+ // expected-error @+1 {{expected all operand types and struct member types are the same}}
+ %0 = "spirv.UMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SDiv
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list