[Mlir-commits] [mlir] cfb9e47 - [mlir][spirv] Define spv.VectorTimesScalar op

Lei Zhang llvmlistbot at llvm.org
Tue Mar 8 12:59:44 PST 2022


Author: Lei Zhang
Date: 2022-03-08T15:58:31-05:00
New Revision: cfb9e474ae360ce59ba9bf05167ba4922d58be5f

URL: https://github.com/llvm/llvm-project/commit/cfb9e474ae360ce59ba9bf05167ba4922d58be5f
DIFF: https://github.com/llvm/llvm-project/commit/cfb9e474ae360ce59ba9bf05167ba4922d58be5f.diff

LOG: [mlir][spirv] Define spv.VectorTimesScalar op

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D121247

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
    mlir/test/Target/SPIRV/arithmetic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c600b31676b3b..468934190d606 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -564,6 +564,40 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv",
 
 // -----
 
+def SPV_VectorTimesScalarOp : SPV_Op<"VectorTimesScalar", [NoSideEffect]> {
+  let summary = "Scale a floating-point vector.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+     The type of Vector must be the same as Result Type. Each component of
+    Vector is multiplied by Scalar.
+
+    Scalar must have the same type as the Component Type in Result Type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$vector,
+    SPV_Float:$scalar
+  );
+
+  let results = (outs
+    VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$result
+  );
+
+  let assemblyFormat = "operands attr-dict `:` `(` type(operands) `)` `->` type($result)";
+}
+
+// -----
+
 def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod",
                                         SPV_Integer,
                                         [UnsignedOp, UsableInSpecConstantOp]> {

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6d494859c03f0..c342331873b37 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4078,6 +4078,7 @@ def SPV_OC_OpSRem                      : I32EnumAttrCase<"OpSRem", 138>;
 def SPV_OC_OpSMod                      : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem                      : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                      : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpVectorTimesScalar         : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix         : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPV_OC_OpIsNan                     : I32EnumAttrCase<"OpIsNan", 156>;
@@ -4202,32 +4203,33 @@ def SPV_OpcodeAttr :
       SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
-      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
-      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
-      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
-      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
-      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
-      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
-      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
-      SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
-      SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
-      SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
-      SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
-      SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
-      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
-      SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan,
+      SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual,
+      SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
+      SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
+      SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
+      SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
+      SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
+      SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
+      SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
+      SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
+      SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
+      SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
+      SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
+      SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
+      SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
+      SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
+      SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange,
+      SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak,
+      SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
+      SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
+      SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
+      SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
+      SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
+      SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine,
+      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
       SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
       SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
       SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1b4b4de3bb170..6a5d7c7c194d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4424,6 +4424,19 @@ LogicalResult spirv::PtrAccessChainOp::verify() {
   return verifyAccessChain(*this, indices());
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorTimesScalarOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::VectorTimesScalarOp::verify() {
+  auto resultType = getType();
+  if (resultType != vector().getType())
+    return emitOpError("vector operand and result type mismatch");
+  if (resultType.cast<VectorType>().getElementType() != scalar().getType())
+    return emitOpError("scalar operand and result element type match");
+  return success();
+}
+
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"

diff  --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index de574b1510c9c..00481828ec634 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -219,3 +219,29 @@ func @umod_scalar(%arg: i32) -> i32 {
   return %0 : i32
 }
 
+// -----
+//===----------------------------------------------------------------------===//
+// spv.VectorTimesScalar
+//===----------------------------------------------------------------------===//
+
+func @vector_times_scalar(%vec: vector<4xf32>, %factor: f32) -> vector<4xf32> {
+  // CHECK: spv.VectorTimesScalar %{{.+}}, %{{.+}} : (vector<4xf32>, f32) -> vector<4xf32>
+  %scaled = spv.VectorTimesScalar %vec, %factor : (vector<4xf32>, f32) -> vector<4xf32>
+  return %scaled : vector<4xf32>
+}
+
+// -----
+
+func @vector_times_scalar(%vec: vector<4xf32>, %factor: f16) -> vector<4xf32> {
+  // expected-error @+1 {{scalar operand and result element type match}}
+  %scaled = spv.VectorTimesScalar %vec, %factor : (vector<4xf32>, f16) -> vector<4xf32>
+  return %scaled : vector<4xf32>
+}
+
+// -----
+
+func @vector_times_scalar(%vec: vector<4xf32>, %factor: f32) -> vector<3xf32> {
+  // expected-error @+1 {{vector operand and result type mismatch}}
+  %scaled = spv.VectorTimesScalar %vec, %factor : (vector<4xf32>, f32) -> vector<3xf32>
+  return %scaled : vector<3xf32>
+}

diff  --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index 9752c0d0e5799..f332e6672d049 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -81,4 +81,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
     spv.Return
   }
+  spv.func @vector_times_scalar(%vec : vector<4xf32>, %factor : f32) "None" {
+    // CHECK: {{%.*}} = spv.VectorTimesScalar {{%.*}}, {{%.*}} : (vector<4xf32>, f32) -> vector<4xf32>
+    %scaled = spv.VectorTimesScalar %vec, %factor : (vector<4xf32>, f32) -> vector<4xf32>
+    spv.Return
+  }
 }


        


More information about the Mlir-commits mailing list