[Mlir-commits] [mlir] 479ee11 - [mlir] [VectorOps] Introduce vector.transpose

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 20 15:47:53 PDT 2020


Author: aartbik
Date: 2020-03-20T15:47:44-07:00
New Revision: 479ee1106153a52861bcec42a58c2fa23ca6d902

URL: https://github.com/llvm/llvm-project/commit/479ee1106153a52861bcec42a58c2fa23ca6d902
DIFF: https://github.com/llvm/llvm-project/commit/479ee1106153a52861bcec42a58c2fa23ca6d902.diff

LOG: [mlir] [VectorOps] Introduce vector.transpose

Summary: Introduced to enable specialized lowering passes that implement transposition operations efficiently.

Reviewers: nicolasvasilache, andydavis1

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76460

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a3864614d05f..51d7962cdfc4 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1269,6 +1269,45 @@ def Vector_TupleOp :
   }];
 }
 
+def Vector_TransposeOp :
+  Vector_Op<"transpose", [NoSideEffect,
+    PredOpTrait<"operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "vector transpose operation";
+  let description = [{
+    Takes an n-D vector and returns the transposed n-D vector defined by
+    the permutation of dimensions in the n-sized integer array attribute.
+    In the operation
+
+    %1 = vector.transpose %0, [i_1, .., i_n]
+      : vector<d_0 x .. x d_{n-1} x f32>
+      to vector<d_{i_1} x .. x d_{i_n} x f32>
+
+    the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+
+    Example:
+
+    %1 = vector.transpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+
+     [ [a, b, c],       [ [a, d],
+       [d, e, f] ]  ->    [b, e],
+                          [c, f] ]
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+  }];
+}
+
 def Vector_TupleGetOp :
   Vector_Op<"tuple_get", [NoSideEffect]>,
     Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index fba0f4af5f26..816aaf9f5948 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1520,6 +1520,35 @@ static void print(OpAsmPrinter &p, TupleOp op) {
 
 static LogicalResult verify(TupleOp op) { return success(); }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that source and result ranks agree, that `transp` is a permutation
+/// of [0, rank), and that result dim k has the size of source dim transp[k].
+static LogicalResult verify(TransposeOp op) {
+  VectorType srcTy = op.getVectorType(), dstTy = op.getResultType();
+  int64_t rank = dstTy.getRank();
+  if (srcTy.getRank() != rank)
+    return op.emitOpError("vector result rank mismatch: ") << rank;
+  auto perm = op.transp().getValue();
+  int64_t permLen = perm.size();
+  if (permLen != rank)
+    return op.emitOpError("transposition length mismatch: ") << permLen;
+  SmallVector<bool, 8> used(rank, false);
+  for (int64_t pos = 0; pos < rank; ++pos) {
+    int64_t src = perm[pos].cast<IntegerAttr>().getInt();
+    if (src < 0 || src >= rank)
+      return op.emitOpError("transposition index out of range: ") << src;
+    if (used[src])
+      return op.emitOpError("duplicate position index: ") << src;
+    used[src] = true;
+    if (dstTy.getDimSize(pos) != srcTy.getDimSize(src))
+      return op.emitOpError("dimension size mismatch at: ") << src;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TupleGetOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d9093edb3765..bb5eca0a361b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1049,6 +1049,41 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
 
 // -----
 
+func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
+  // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+  %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
+}
+
+// -----
+
+func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 3}}
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_oob(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition index out of range: 2}}
+  %0 = vector.transpose %arg0, [2, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_dup(%arg0: vector<4x4xf32>) {
+  // expected-error@+1 {{'vector.transpose' op duplicate position index: 0}}
+  %0 = vector.transpose %arg0, [0, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
+  // expected-error@+1 {{'vector.transpose' op dimension size mismatch at: 0}}
+  %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x3x7x11xi32>
+}
+
+// -----
+
 func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f286b932a472..a3b3fcc9c23c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -315,3 +315,17 @@ func @reduce_int(%arg0: vector<16xi32>) -> i32 {
   // CHECK:    return %[[X]] : i32
   return %0 : i32
 }
+
+// CHECK-LABEL: transpose_fp
+func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
+  // CHECK: vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+  return %0 : vector<7x3xf32>
+}
+
+// CHECK-LABEL: transpose_int
+func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
+  // CHECK: vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+  %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+  return %0 : vector<2x11x7x3xi32>
+}


        


More information about the Mlir-commits mailing list