[Mlir-commits] [mlir] 2bcb620 - [mlir][spirv] Add TransposeOp

Lei Zhang llvmlistbot at llvm.org
Wed Jun 24 17:42:01 PDT 2020


Author: HazemAbdelhafez
Date: 2020-06-24T20:41:54-04:00
New Revision: 2bcb62086884fdb5248a8fe9095c1ad08e2ecd50

URL: https://github.com/llvm/llvm-project/commit/2bcb62086884fdb5248a8fe9095c1ad08e2ecd50
DIFF: https://github.com/llvm/llvm-project/commit/2bcb62086884fdb5248a8fe9095c1ad08e2ecd50.diff

LOG: [mlir][spirv] Add TransposeOp

Add Transpose operation to SPIRV dialect.

Differential Revision: https://reviews.llvm.org/D82308

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
    mlir/test/Dialect/SPIRV/matrix-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index e64abd70333d..6bff480ab83b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3141,6 +3141,7 @@ def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>
 def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
+def SPV_OC_OpTranspose                 : I32EnumAttrCase<"OpTranspose", 84>;
 def SPV_OC_OpConvertFToU               : I32EnumAttrCase<"OpConvertFToU", 109>;
 def SPV_OC_OpConvertFToS               : I32EnumAttrCase<"OpConvertFToS", 110>;
 def SPV_OC_OpConvertSToF               : I32EnumAttrCase<"OpConvertSToF", 111>;
@@ -3265,20 +3266,21 @@ def SPV_OpcodeAttr :
       SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
       SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
       SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
-      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
-      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
-      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
-      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
-      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
-      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
+      SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
+      SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
+      SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
+      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
+      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
+      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
+      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
index c71e0cd54e99..07d7fd1093c2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
@@ -45,6 +45,13 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
     ```
   }];
 
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Matrix]>
+  ];
+
   let arguments = (ins
     SPV_AnyMatrix:$matrix,
     SPV_Float:$scalar
@@ -72,4 +79,58 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
 
 // -----
 
+def SPV_TransposeOp : SPV_Op<"Transpose", []> {
+  let summary = "Transpose a matrix.";
+
+  let description = [{
+    Result Type must be an OpTypeMatrix.
+
+    Matrix must be an object of type OpTypeMatrix. The number of columns and
+    the column size of Matrix must be the reverse of those in Result Type.
+    The types of the scalar components in Matrix and Result Type must be the
+    same.
+
+    Matrix must be of type OpTypeMatrix.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    transpose-op ::= ssa-id `=` `spv.Transpose` ssa-use `:` matrix-type `->`
+    matrix-type
+
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.Transpose %matrix : !spv.matrix<2 x vector<3xf32>> ->
+    !spv.matrix<3 x vector<2xf32>>
+
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyMatrix:$matrix
+  );
+
+  let results = (outs
+    SPV_AnyMatrix:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($matrix) `->` type($result)
+  }];
+
+  let verifier = [{ return verifyTranspose(*this); }];
+}
+
+// -----
+
 #endif // SPIRV_MATRIX_OPS
\ No newline at end of file

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 3742bab414ec..6415218f74c2 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2815,6 +2815,36 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.Transpose
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyTranspose(spirv::TransposeOp op) {
+  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
+  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+  // Verify that the input and output matrices have correct shapes.
+  if (auto inputMatrixColumns =
+          inputMatrix.getElementType().dyn_cast<VectorType>()) {
+    if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
+      return op.emitError("input matrix rows count must be equal to "
+                          "output matrix columns count");
+    if (auto resultMatrixColumns =
+            resultMatrix.getElementType().dyn_cast<VectorType>()) {
+      if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
+        return op.emitError("input matrix columns count must be equal "
+                            "to output matrix rows count");
+
+      // Verify that the input and output matrices have the same component type
+      if (inputMatrixColumns.getElementType() !=
+          resultMatrixColumns.getElementType())
+        return op.emitError("input and output matrices must have the "
+                            "same component type");
+    }
+  }
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
index e10bfc88afb0..6db8a7666768 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -22,6 +22,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
 
   }
+
+  // CHECK-LABEL: @matrix_transpose_1
+  spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+    %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
+  }
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
index 8079b4ba88f6..09bdf3983005 100644
--- a/mlir/test/Dialect/SPIRV/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
@@ -2,11 +2,25 @@
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @matrix_times_scalar
-  spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
+  spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
     // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
     %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
     spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
   }
+
+  // CHECK-LABEL: @matrix_transpose_1
+  spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+    %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_transpose_2
+  spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
 }
 
 // -----
@@ -37,5 +51,26 @@ func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 :
    %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
 }
 
+// -----
+
+func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+   // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
+   %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>>
+   spv.Return
+}
+
+// -----
 
+func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+   // expected-error @+1 {{input matrix columns count must be equal to output matrix rows count}}
+   %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<4xf32>>
+   spv.Return
+}
 
+// -----
+
+func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+   // expected-error @+1 {{input and output matrices must have the same component type}}
+   %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
+   spv.Return
+}


        


More information about the Mlir-commits mailing list