[Mlir-commits] [mlir] 3027582 - [mlir][spirv] Define `OpSMulExtended` and `OpUMulExtended` ops

Jakub Kuderski llvmlistbot at llvm.org
Wed Nov 23 16:06:12 PST 2022


Author: Jakub Kuderski
Date: 2022-11-23T19:04:56-05:00
New Revision: 30275821f1cacd855c7f9246351242223bed9d29

URL: https://github.com/llvm/llvm-project/commit/30275821f1cacd855c7f9246351242223bed9d29
DIFF: https://github.com/llvm/llvm-project/commit/30275821f1cacd855c7f9246351242223bed9d29.diff

LOG: [mlir][spirv] Define `OpSMulExtended` and `OpUMulExtended` ops

These ops perform exact multiplication and return a two-member struct:
member 0 holds the low-order bits and member 1 the high-order bits.

Also factor out common code shared between 'extended binary ops'.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138624

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index b124625f70543..1d6c98d017350 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -53,7 +53,32 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
-  }
+}
+
+class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
+                                       list<Trait> traits = []> :
+      // Result type is a struct with two operand-typed elements.
+      SPIRV_BinaryOp<mnemonic, SPIRV_AnyStruct, SPIRV_Integer, traits> {
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
+  );
+
+  let results = (outs
+    SPIRV_AnyStruct:$result
+  );
+
+  let builders = [
+    OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
+      build($_builder, $_state,
+            ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
+            operand1, operand2);
+    }]>
+  ];
+
+  // These ops require a custom verifier.
+  let hasVerifier = 1;
+}
 
 // -----
 
@@ -321,9 +346,8 @@ def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
 
 // -----
 
-def SPIRV_IAddCarryOp : SPIRV_BinaryOp<"IAddCarry",
-                                   SPIRV_AnyStruct, SPIRV_Integer,
-                                   [Commutative, Pure]> {
+def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
+                                                         [Commutative, Pure]> {
   let summary = [{
     Integer addition of Operand 1 and Operand 2, including the carry.
   }];
@@ -355,25 +379,6 @@ def SPIRV_IAddCarryOp : SPIRV_BinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
-
-  let arguments = (ins
-    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
-    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
-  );
-
-  let results = (outs
-    SPIRV_AnyStruct:$result
-  );
-
-  let builders = [
-    OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
-      build($_builder, $_state,
-            ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
-            operand1, operand2);
-    }]>
-  ];
-
-  let hasVerifier = 1;
 }
 
 // -----
@@ -418,6 +423,75 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
 
 // -----
 
+def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
+                                                            [Pure, Commutative]> {
+  let summary = [{
+    Result is the full value of the signed integer multiplication of Operand
+    1 and Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct.  The struct must have two
+    members, and the two members must be the same type.  The member type
+    must be a scalar or vector of integer type.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as signed integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits of the multiplication.
+
+    Member 1 of the result gets the high-order bits of the multiplication.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+}
+
+// -----
+
+def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
+                                                            [Pure, Commutative]> {
+  let summary = [{
+    Result is the full value of the unsigned integer multiplication of
+    Operand 1 and Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct.  The struct must have two
+    members, and the two members must be the same type.  The member type
+    must be a scalar or vector of integer type, whose Signedness operand is
+    0.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as unsigned integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits of the multiplication.
+
+    Member 1 of the result gets the high-order bits of the multiplication.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
                                         SPIRV_Integer,
                                         [UsableInSpecConstantOp]> {
@@ -458,8 +532,8 @@ def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
 
 // -----
 
-def SPIRV_ISubBorrowOp : SPIRV_BinaryOp<"ISubBorrow", SPIRV_AnyStruct, SPIRV_Integer,
-                                    [Pure]> {
+def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
+                                                          [Pure]> {
   let summary = [{
     Result is the unsigned integer subtraction of Operand 2 from Operand 1,
     and what it needed to borrow.
@@ -494,25 +568,6 @@ def SPIRV_ISubBorrowOp : SPIRV_BinaryOp<"ISubBorrow", SPIRV_AnyStruct, SPIRV_Int
     %2 = spirv.ISubBorrow %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
-
-  let arguments = (ins
-    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand1,
-    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$operand2
-  );
-
-  let results = (outs
-    SPIRV_AnyStruct:$result
-  );
-
-  let builders = [
-    OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{
-      build($_builder, $_state,
-            ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}),
-            operand1, operand2);
-    }]>
-  ];
-
-  let hasVerifier = 1;
 }
 
 // -----

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 77fe5ac9e84e9..400947a4043e3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4243,6 +4243,8 @@ def SPIRV_OC_OpMatrixTimesScalar          : I32EnumAttrCase<"OpMatrixTimesScalar
 def SPIRV_OC_OpMatrixTimesMatrix          : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPIRV_OC_OpIAddCarry                  : I32EnumAttrCase<"OpIAddCarry", 149>;
 def SPIRV_OC_OpISubBorrow                 : I32EnumAttrCase<"OpISubBorrow", 150>;
+def SPIRV_OC_OpUMulExtended               : I32EnumAttrCase<"OpUMulExtended", 151>;
+def SPIRV_OC_OpSMulExtended               : I32EnumAttrCase<"OpSMulExtended", 152>;
 def SPIRV_OC_OpIsNan                      : I32EnumAttrCase<"OpIsNan", 156>;
 def SPIRV_OC_OpIsInf                      : I32EnumAttrCase<"OpIsInf", 157>;
 def SPIRV_OC_OpOrdered                    : I32EnumAttrCase<"OpOrdered", 162>;
@@ -4372,17 +4374,17 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, SPIRV_OC_OpImageDrefGather,
       SPIRV_OC_OpImage, SPIRV_OC_OpImageQuerySize, SPIRV_OC_OpConvertFToU,
       SPIRV_OC_OpConvertFToS, SPIRV_OC_OpConvertSToF, SPIRV_OC_OpConvertUToF,
-      SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric, 
+      SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric,
       SPIRV_OC_OpGenericCastToPtr, SPIRV_OC_OpGenericCastToPtrExplicit, SPIRV_OC_OpBitcast,
       SPIRV_OC_OpSNegate, SPIRV_OC_OpFNegate, SPIRV_OC_OpIAdd, SPIRV_OC_OpFAdd,
       SPIRV_OC_OpISub, SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod,
       SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, SPIRV_OC_OpVectorTimesScalar,
       SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry,
-      SPIRV_OC_OpISubBorrow, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered,
-      SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual,
-      SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect,
-      SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,
+      SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
+      SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual,
+      SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot,
+      SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,
       SPIRV_OC_OpSGreaterThan, SPIRV_OC_OpUGreaterThanEqual, SPIRV_OC_OpSGreaterThanEqual,
       SPIRV_OC_OpULessThan, SPIRV_OC_OpSLessThan, SPIRV_OC_OpULessThanEqual,
       SPIRV_OC_OpSLessThanEqual, SPIRV_OC_OpFOrdEqual, SPIRV_OC_OpFUnordEqual,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c16d3d0b2d8e4..606562292ebcf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/APFloat.h"
@@ -31,6 +32,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
+#include <cassert>
 #include <numeric>
 
 using namespace mlir;
@@ -763,6 +765,53 @@ static inline bool isMergeBlock(Block &block) {
          isa<spirv::MergeOp>(block.front());
 }
 
+template <typename ExtendedBinaryOp>
+static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
+  auto resultType = op.getType().template cast<spirv::StructType>();
+  if (resultType.getNumElements() != 2)
+    return op.emitOpError("expected result struct type containing two members");
+
+  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
+                        resultType.getElementType(0),
+                        resultType.getElementType(1)}))
+    return op.emitOpError(
+        "expected all operand types and struct member types are the same");
+
+  return success();
+}
+
+static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
+                                                   OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operands) || parser.parseColon())
+    return failure();
+
+  Type resultType;
+  SMLoc loc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+
+  auto structType = resultType.dyn_cast<spirv::StructType>();
+  if (!structType || structType.getNumElements() != 2)
+    return parser.emitError(loc, "expected spirv.struct type with two members");
+
+  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
+  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+static void printArithmeticExtendedBinaryOp(Operation *op,
+                                            OpAsmPrinter &printer) {
+  printer << ' ';
+  printer.printOptionalAttrDict(op->getAttrs());
+  printer.printOperands(op->getOperands());
+  printer << " : " << op->getResultTypes().front();
+}
+
 //===----------------------------------------------------------------------===//
 // Common parsers and printers
 //===----------------------------------------------------------------------===//
@@ -2990,48 +3039,16 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::IAddCarryOp::verify() {
-  auto resultType = getType().cast<spirv::StructType>();
-  if (resultType.getNumElements() != 2)
-    return emitOpError("expected result struct type containing two members");
-
-  if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
-                        resultType.getElementType(0),
-                        resultType.getElementType(1)}))
-    return emitOpError(
-        "expected all operand types and struct member types are the same");
-
-  return success();
+  return ::verifyArithmeticExtendedBinaryOp(*this);
 }
 
 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseOperandList(operands) || parser.parseColon())
-    return failure();
-
-  Type resultType;
-  SMLoc loc = parser.getCurrentLocation();
-  if (parser.parseType(resultType))
-    return failure();
-
-  auto structType = resultType.dyn_cast<spirv::StructType>();
-  if (!structType || structType.getNumElements() != 2)
-    return parser.emitError(loc, "expected spirv.struct type with two members");
-
-  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
-  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
-    return failure();
-
-  result.addTypes(resultType);
-  return success();
+  return ::parseArithmeticExtendedBinaryOp(parser, result);
 }
 
 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
-  printer << ' ';
-  printer.printOptionalAttrDict((*this)->getAttrs());
-  printer.printOperands((*this)->getOperands());
-  printer << " : " << getType();
+  ::printArithmeticExtendedBinaryOp(*this, printer);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3039,48 +3056,50 @@ void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::ISubBorrowOp::verify() {
-  auto resultType = getType().cast<spirv::StructType>();
-  if (resultType.getNumElements() != 2)
-    return emitOpError("expected result struct type containing two members");
-
-  if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
-                        resultType.getElementType(0),
-                        resultType.getElementType(1)}))
-    return emitOpError(
-        "expected all operand types and struct member types are the same");
-
-  return success();
+  return ::verifyArithmeticExtendedBinaryOp(*this);
 }
 
 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
                                        OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseOperandList(operands) || parser.parseColon())
-    return failure();
+  return ::parseArithmeticExtendedBinaryOp(parser, result);
+}
 
-  Type resultType;
-  auto loc = parser.getCurrentLocation();
-  if (parser.parseType(resultType))
-    return failure();
+void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
+  ::printArithmeticExtendedBinaryOp(*this, printer);
+}
 
-  auto structType = resultType.dyn_cast<spirv::StructType>();
-  if (!structType || structType.getNumElements() != 2)
-    return parser.emitError(loc, "expected spirv.struct type with two members");
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
 
-  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
-  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
-    return failure();
+LogicalResult spirv::SMulExtendedOp::verify() {
+  return ::verifyArithmeticExtendedBinaryOp(*this);
+}
 
-  result.addTypes(resultType);
-  return success();
+ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return ::parseArithmeticExtendedBinaryOp(parser, result);
 }
 
-void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
-  printer << ' ';
-  printer.printOptionalAttrDict((*this)->getAttrs());
-  printer.printOperands((*this)->getOperands());
-  printer << " : " << getType();
+void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
+  ::printArithmeticExtendedBinaryOp(*this, printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::UMulExtendedOp::verify() {
+  return ::verifyArithmeticExtendedBinaryOp(*this);
+}
+
+ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return ::parseArithmeticExtendedBinaryOp(parser, result);
+}
+
+void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
+  ::printArithmeticExtendedBinaryOp(*this, printer);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index c59179eda9fd7..9617204d3419c 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -254,6 +254,110 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smul_extended_scalar
+func.func @smul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @smul_extended_vector
+func.func @smul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> {
+  // expected-error @+1 {{expected spirv.struct type with two members}}
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)>
+  return %0 : !spirv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32)> {
+  // expected-error @+1 {{expected result struct type containing two members}}
+  %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)>
+  return %0 : !spirv.struct<(i32)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)>
+  return %0 : !spirv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @smul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spirv.SMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umul_extended_scalar
+func.func @umul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)>
+  %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umul_extended_vector
+func.func @umul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> {
+  // expected-error @+1 {{expected spirv.struct type with two members}}
+  %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)>
+  return %0 : !spirv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32)> {
+  // expected-error @+1 {{expected result struct type containing two members}}
+  %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)>
+  return %0 : !spirv.struct<(i32)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)>
+  return %0 : !spirv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @umul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spirv.UMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SDiv
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list