[Mlir-commits] [mlir] e901188 - [mlir][spirv] Define spv.VectorShuffle

Lei Zhang llvmlistbot at llvm.org
Tue Feb 2 08:12:09 PST 2021


Author: Lei Zhang
Date: 2021-02-02T11:08:56-05:00
New Revision: e901188cf9e34b54db7b1c5359264549d5f5be4f

URL: https://github.com/llvm/llvm-project/commit/e901188cf9e34b54db7b1c5359264549d5f5be4f
DIFF: https://github.com/llvm/llvm-project/commit/e901188cf9e34b54db7b1c5359264549d5f5be4f.diff

LOG: [mlir][spirv] Define spv.VectorShuffle

This patch adds the basic op definition, parser/printer, and verifier.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95825

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
    mlir/test/Target/SPIRV/composite-op.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0c2d91133b27..afeca5532e07 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3187,6 +3187,7 @@ def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpVectorExtractDynamic      : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpVectorInsertDynamic       : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
+def SPV_OC_OpVectorShuffle             : I32EnumAttrCase<"OpVectorShuffle", 79>;
 def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3327,7 +3328,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
       SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
       SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
-      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index d73217117027..2ac4fb0cc5f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -165,17 +165,17 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert",
 
   let builders = [
     OpBuilderDAG<(ins "Value":$object, "Value":$composite,
-      "ArrayRef<int32_t>":$indices)>
+                      "ArrayRef<int32_t>":$indices)>
   ];
 }
 
 // -----
 
-def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
-            [NoSideEffect,
-            TypesMatchWith<"type of 'result' matches element type of 'vector'",
-                     "vector", "result",
-                     "$_self.cast<mlir::VectorType>().getElementType()">]> {
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic", [
+    NoSideEffect,
+    TypesMatchWith<"type of 'result' matches element type of 'vector'",
+                   "vector", "result",
+                   "$_self.cast<mlir::VectorType>().getElementType()">]> {
   let summary = [{
     Extract a single, dynamically selected, component of a vector.
   }];
@@ -194,13 +194,6 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
 
     <!-- End of AutoGen section -->
 
-    ```
-	scalar-type ::= integer-type | float-type | boolean-type 
-    vector-extract-dynamic-op ::= `spv.VectorExtractDynamic ` ssa-use `[` ssa-use `]`
-                                    `:` `vector<` integer-literal `x` scalar-type `>` `,`
-                                    integer-type
-    ```mlir
-
     #### Example:
 
     ```
@@ -226,12 +219,13 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
 
 // -----
 
-def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
-        [NoSideEffect,
-        TypesMatchWith<"type of 'component' matches element type of 'vector'",
-                "vector", "component",
-                "$_self.cast<mlir::VectorType>().getElementType()">,
-                AllTypesMatch<["vector", "result"]>]> {
+def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic", [
+    NoSideEffect,
+    TypesMatchWith<
+      "type of 'component' matches element type of 'vector'",
+      "vector", "component",
+      "$_self.cast<mlir::VectorType>().getElementType()">,
+    AllTypesMatch<["vector", "result"]>]> {
   let summary = [{
     Make a copy of a vector, with a single, variably selected, component
     modified.
@@ -289,4 +283,64 @@ def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
 
 // -----
 
+def SPV_VectorShuffleOp : SPV_Op<"VectorShuffle", [
+    NoSideEffect, AllElementTypesMatch<["vector1", "vector2", "result"]>]> {
+  let summary = [{
+    Select arbitrary components from two vectors to make a new vector.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypeVector. The number of components in Result
+    Type must be the same as the number of Component operands.
+
+    Vector 1 and Vector 2 must both have vector types, with the same
+    Component Type as Result Type. They do not have to have the same number
+    of components as Result Type or with each other. They are logically
+    concatenated, forming a single vector with Vector 1's components
+    appearing before Vector 2's. The components of this logical vector are
+    logically numbered with a single consecutive set of numbers from 0 to
+    N - 1, where N is the total number of components.
+
+    Components are these logical numbers (see above), selecting which of the
+    logically numbered components form the result. Each component is an
+    unsigned 32-bit integer. They can select the components in any order
+    and can repeat components. The first component of the result is selected
+    by the first Component operand, the second component of the result is
+    selected by the second Component operand, etc. A Component literal may
+    also be FFFFFFFF, which means the corresponding result component has no
+    source and is undefined. All Component literals must either be FFFFFFFF
+    or in [0, N - 1] (inclusive).
+
+    Note: A vector "swizzle" can be done by using the vector for both Vector
+    operands, or using an OpUndef for one of the Vector operands.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32]
+                           %vector1: vector<4xf32>, %vector2: vector<2xf32>
+                        -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector1,
+    SPV_Vector:$vector2,
+    I32ArrayAttr:$components
+  );
+
+  let results = (outs
+    SPV_Vector:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $components $vector1 `:` type($vector1) `,`
+                          $vector2 `:` type($vector2) `->` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_COMPOSITE_OPS

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 4506447b0503..fc1a705107ee 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
@@ -3036,6 +3037,39 @@ static LogicalResult verify(spirv::VariableOp varOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+// Verifies that the number of component selectors equals the element count of
+// the result vector, and that each selector either indexes into the logical
+// concatenation of vector1 and vector2 or is 0xFFFFFFFF (undefined component).
+static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
+  VectorType resultType = shuffleOp.getType().cast<VectorType>();
+
+  size_t numResultElements = resultType.getNumElements();
+  if (numResultElements != shuffleOp.components().size())
+    return shuffleOp.emitOpError("result type element count (")
+           << numResultElements
+           << ") mismatch with the number of component selectors ("
+           << shuffleOp.components().size() << ")";
+
+  size_t totalSrcElements =
+      shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
+      shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
+
+  for (const auto &selector :
+       shuffleOp.components().getAsValueRange<IntegerAttr>()) {
+    uint32_t index = selector.getZExtValue();
+    if (index >= totalSrcElements &&
+        index != std::numeric_limits<uint32_t>::max())
+      return shuffleOp.emitOpError("component selector ")
+             << index << " out of range: expected to be in [0, "
+             << totalSrcElements << ") or 0xffffffff";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.CooperativeMatrixLoadNV
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 77d091fe1107..08b7fd26e1db 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -283,3 +283,31 @@ func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector
   %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
   return %0 : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_out_of_range_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
+  %0 = spv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}

diff  --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 468d5419081a..9b0462b5332f 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -21,4 +21,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
     spv.ReturnValue %0: vector<4xf32>
   }
+  spv.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+    // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+    spv.ReturnValue %0: vector<3xf32>
+  }
 }


        


More information about the Mlir-commits mailing list