[Mlir-commits] [mlir] 479ee11 - [mlir] [VectorOps] Introduce vector.transpose
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 20 15:47:53 PDT 2020
Author: aartbik
Date: 2020-03-20T15:47:44-07:00
New Revision: 479ee1106153a52861bcec42a58c2fa23ca6d902
URL: https://github.com/llvm/llvm-project/commit/479ee1106153a52861bcec42a58c2fa23ca6d902
DIFF: https://github.com/llvm/llvm-project/commit/479ee1106153a52861bcec42a58c2fa23ca6d902.diff
LOG: [mlir] [VectorOps] Introduce vector.transpose
Summary: Introduced in order to enable specialized lowering passes that implement transposition operations efficiently.
Reviewers: nicolasvasilache, andydavis1
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76460
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a3864614d05f..51d7962cdfc4 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1269,6 +1269,45 @@ def Vector_TupleOp :
}];
}
+def Vector_TransposeOp :
+ Vector_Op<"transpose", [NoSideEffect,
+ PredOpTrait<"operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
+ Results<(outs AnyVector:$result)> {
+ let summary = "vector transpose operation";
+ let description = [{
+ Takes an n-D vector and returns the transposed n-D vector defined by
+ the permutation of ranks in the n-sized integer array attribute.
+ In the operation
+
+ %1 = vector.transpose %0, [i_1, .., i_n]
+ : vector<d_1 x .. x d_n x f32>
+ to vector<d_trans[0] x .. x d_trans[n-1] x f32>
+
+ the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+
+ Example:
+
+ %1 = vector.transpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+
+ [ [a, b, c], [ [a, d],
+ [d, e, f] ] -> [b, e],
+ [c, f] ]
+ }];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return vector().getType().cast<VectorType>();
+ }
+ VectorType getResultType() {
+ return result().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = [{
+ $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+ }];
+}
+
def Vector_TupleGetOp :
Vector_Op<"tuple_get", [NoSideEffect]>,
Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index fba0f4af5f26..816aaf9f5948 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1520,6 +1520,35 @@ static void print(OpAsmPrinter &p, TupleOp op) {
static LogicalResult verify(TupleOp op) { return success(); }
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(TransposeOp op) {
+ VectorType vectorType = op.getVectorType();
+ VectorType resultType = op.getResultType();
+ int64_t rank = resultType.getRank();
+ if (vectorType.getRank() != rank)
+ return op.emitOpError("vector result rank mismatch: ") << rank;
+ // Verify transposition array.
+ auto transpAttr = op.transp().getValue();
+ int64_t size = transpAttr.size();
+ if (rank != size)
+ return op.emitOpError("transposition length mismatch: ") << size;
+ SmallVector<bool, 8> seen(rank, false);
+ for (auto ta : llvm::enumerate(transpAttr)) {
+ int64_t i = ta.value().cast<IntegerAttr>().getInt();
+ if (i < 0 || i >= rank)
+ return op.emitOpError("transposition index out of range: ") << i;
+ if (seen[i])
+ return op.emitOpError("duplicate position index: ") << i;
+ seen[i] = true;
+ if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
+ return op.emitOpError("dimension size mismatch at: ") << i;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TupleGetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d9093edb3765..bb5eca0a361b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1049,6 +1049,41 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
// -----
+func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
+ // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+ %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
+}
+
+// -----
+
+func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
+ // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 3}}
+ %0 = vector.transpose %arg0, [2, 0, 1] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_oob(%arg0: vector<4x4xf32>) {
+ // expected-error@+1 {{'vector.transpose' op transposition index out of range: 2}}
+ %0 = vector.transpose %arg0, [2, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_index_dup(%arg0: vector<4x4xf32>) {
+ // expected-error@+1 {{'vector.transpose' op duplicate position index: 0}}
+ %0 = vector.transpose %arg0, [0, 0] : vector<4x4xf32> to vector<4x4xf32>
+}
+
+// -----
+
+func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
+ // expected-error@+1 {{'vector.transpose' op dimension size mismatch at: 0}}
+ %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x3x7x11xi32>
+}
+
+// -----
+
func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
// expected-error@+1 {{expects operand to be a memref with no layout}}
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f286b932a472..a3b3fcc9c23c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -315,3 +315,15 @@ func @reduce_int(%arg0: vector<16xi32>) -> i32 {
// CHECK: return %[[X]] : i32
return %0 : i32
}
+
+// CHECK-LABEL: transpose_fp
+func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+ return %0 : vector<7x3xf32>
+}
+
+// CHECK-LABEL: transpose_int
+func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
+ %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+ return %0 : vector<2x11x7x3xi32>
+}
More information about the Mlir-commits
mailing list