[Mlir-commits] [mlir] e901188 - [mlir][spirv] Define spv.VectorShuffle
Lei Zhang
llvmlistbot at llvm.org
Tue Feb 2 08:12:09 PST 2021
Author: Lei Zhang
Date: 2021-02-02T11:08:56-05:00
New Revision: e901188cf9e34b54db7b1c5359264549d5f5be4f
URL: https://github.com/llvm/llvm-project/commit/e901188cf9e34b54db7b1c5359264549d5f5be4f
DIFF: https://github.com/llvm/llvm-project/commit/e901188cf9e34b54db7b1c5359264549d5f5be4f.diff
LOG: [mlir][spirv] Define spv.VectorShuffle
This patch adds the basic op definition, parser/printer, and verifier.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D95825
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
mlir/test/Target/SPIRV/composite-op.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0c2d91133b27..afeca5532e07 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3187,6 +3187,7 @@ def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
def SPV_OC_OpVectorInsertDynamic : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
+def SPV_OC_OpVectorShuffle : I32EnumAttrCase<"OpVectorShuffle", 79>;
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3327,7 +3328,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
- SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+ SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index d73217117027..2ac4fb0cc5f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -165,17 +165,17 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert",
let builders = [
OpBuilderDAG<(ins "Value":$object, "Value":$composite,
- "ArrayRef<int32_t>":$indices)>
+ "ArrayRef<int32_t>":$indices)>
];
}
// -----
-def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
- [NoSideEffect,
- TypesMatchWith<"type of 'result' matches element type of 'vector'",
- "vector", "result",
- "$_self.cast<mlir::VectorType>().getElementType()">]> {
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic", [
+ NoSideEffect,
+ TypesMatchWith<"type of 'result' matches element type of 'vector'",
+ "vector", "result",
+ "$_self.cast<mlir::VectorType>().getElementType()">]> {
let summary = [{
Extract a single, dynamically selected, component of a vector.
}];
@@ -194,13 +194,6 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
<!-- End of AutoGen section -->
- ```
- scalar-type ::= integer-type | float-type | boolean-type
- vector-extract-dynamic-op ::= `spv.VectorExtractDynamic ` ssa-use `[` ssa-use `]`
- `:` `vector<` integer-literal `x` scalar-type `>` `,`
- integer-type
- ```mlir
-
#### Example:
```
@@ -226,12 +219,13 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
// -----
-def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
- [NoSideEffect,
- TypesMatchWith<"type of 'component' matches element type of 'vector'",
- "vector", "component",
- "$_self.cast<mlir::VectorType>().getElementType()">,
- AllTypesMatch<["vector", "result"]>]> {
+def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic", [
+ NoSideEffect,
+ TypesMatchWith<
+ "type of 'component' matches element type of 'vector'",
+ "vector", "component",
+ "$_self.cast<mlir::VectorType>().getElementType()">,
+ AllTypesMatch<["vector", "result"]>]> {
let summary = [{
Make a copy of a vector, with a single, variably selected, component
modified.
@@ -289,4 +283,64 @@ def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
// -----
+def SPV_VectorShuffleOp : SPV_Op<"VectorShuffle", [
+ NoSideEffect, AllElementTypesMatch<["vector1", "vector2", "result"]>]> {
+ let summary = [{
+ Select arbitrary components from two vectors to make a new vector.
+ }];
+
+ let description = [{
+ Result Type must be an OpTypeVector. The number of components in Result
+ Type must be the same as the number of Component operands.
+
+ Vector 1 and Vector 2 must both have vector types, with the same
+ Component Type as Result Type. They do not have to have the same number
+ of components as Result Type or with each other. They are logically
+ concatenated, forming a single vector with Vector 1's components
+ appearing before Vector 2's. The components of this logical vector are
+ logically numbered with a single consecutive set of numbers from 0 to
+ N - 1, where N is the total number of components.
+
+ Components are these logical numbers (see above), selecting which of the
+ logically numbered components form the result. Each component is an
+ unsigned 32-bit integer. They can select the components in any order
+ and can repeat components. The first component of the result is selected
+ by the first Component operand, the second component of the result is
+ selected by the second Component operand, etc. A Component literal may
+ also be FFFFFFFF, which means the corresponding result component has no
+ source and is undefined. All Component literals must either be FFFFFFFF
+ or in [0, N - 1] (inclusive).
+
+ Note: A vector "swizzle" can be done by using the vector for both Vector
+ operands, or using an OpUndef for one of the Vector operands.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32]
+ %vector1: vector<4xf32>, %vector2: vector<2xf32>
+ -> vector<3xf32>
+ ```
+
+ An undefined selector may be written as `0xffffffff: i32`; the custom
+ printer renders it back as `-1 : i32`.
+ }];
+
+ let arguments = (ins
+ SPV_Vector:$vector1,
+ SPV_Vector:$vector2,
+ I32ArrayAttr:$components
+ );
+
+ let results = (outs
+ SPV_Vector:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict $components $vector1 `:` type($vector1) `,`
+ $vector2 `:` type($vector2) `->` type($result)
+ }];
+}
+
+// -----
+
#endif // MLIR_DIALECT_SPIRV_IR_COMPOSITE_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 4506447b0503..fc1a705107ee 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
@@ -3036,6 +3037,36 @@ static LogicalResult verify(spirv::VariableOp varOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
+ VectorType resultType = shuffleOp.getType().cast<VectorType>();
+ ArrayAttr selectors = shuffleOp.components();
+
+ // Every result element is chosen by exactly one component selector.
+ size_t numResultElements = resultType.getNumElements();
+ if (numResultElements != selectors.size())
+ return shuffleOp.emitOpError("result type element count (")
+ << numResultElements
+ << ") mismatch with the number of component selectors ("
+ << selectors.size() << ")";
+
+ // Selectors index into the logical concatenation of vector1 and vector2.
+ size_t totalSrcElements =
+ shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
+ shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
+
+ // 0xffffffff is the "undefined component" selector and is always allowed.
+ constexpr uint32_t kUndefSelector = std::numeric_limits<uint32_t>::max();
+ for (const APInt &selector : selectors.getAsValueRange<IntegerAttr>()) {
+ uint32_t index = selector.getZExtValue();
+ if (index >= totalSrcElements && index != kUndefSelector)
+ return shuffleOp.emitOpError("component selector ")
+ << index << " out of range: expected to be in [0, "
+ << totalSrcElements << ") or 0xffffffff";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 77d091fe1107..08b7fd26e1db 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -283,3 +283,31 @@ func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector
%0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
return %0 : vector<4xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+ // The 0xffffffff (undefined) selector is printed back as -1 : i32.
+ // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+ %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+ return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+ // Four selectors are supplied but the result vector has only three elements.
+ // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
+ %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+ return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_out_of_range_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+ // Sources provide 4 + 2 = 6 components, so selector 7 is out of range.
+ // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
+ %0 = spv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+ return %0: vector<3xf32>
+}
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 468d5419081a..9b0462b5332f 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -21,4 +21,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
spv.ReturnValue %0: vector<4xf32>
}
+ spv.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+ // The 0xffffffff (undefined) selector round-trips as -1 : i32.
+ // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+ %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+ spv.ReturnValue %0: vector<3xf32>
+ }
}
More information about the Mlir-commits
mailing list