[Mlir-commits] [mlir] 941005f - [mlir] NFC - Add a builder to vector.transpose
Nicolas Vasilache
llvmlistbot at llvm.org
Thu May 21 02:23:40 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-21T05:18:58-04:00
New Revision: 941005f51ac4a53ea6cc92dfdf06069c90c02f55
URL: https://github.com/llvm/llvm-project/commit/941005f51ac4a53ea6cc92dfdf06069c90c02f55
DIFF: https://github.com/llvm/llvm-project/commit/941005f51ac4a53ea6cc92dfdf06069c90c02f55.diff
LOG: [mlir] NFC - Add a builder to vector.transpose
Summary: Also expose some more vector ops to EDSCs.
Differential Revision: https://reviews.llvm.org/D80333
Added:
Modified:
mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/EDSC/Builders.h
mlir/lib/Dialect/Vector/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
index 6b5c4be7b2f4..3a3551ddc3eb 100644
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
@@ -20,9 +20,11 @@ using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;
using vector_extract = ValueBuilder<vector::ExtractOp>;
using vector_matmul = ValueBuilder<vector::MatmulOp>;
+using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
using vector_print = OperationBuilder<vector::PrintOp>;
using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
+using vector_transpose = ValueBuilder<vector::TransposeOp>;
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
// NOTE(review): vector_insert and vector_fma below repeat aliases declared
// earlier in this list with the same target types; the redeclaration is legal
// but redundant and could be dropped in a follow-up.
using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 575b99d51c97..264c8ad034c8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1385,6 +1385,9 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
+ let builders = [OpBuilder<
+ "OpBuilder &builder, OperationState &result, Value vector, "
+ "ArrayRef<int64_t> transp">];
let extraClassDeclaration = [{
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
@@ -1393,6 +1396,7 @@ def Vector_TransposeOp :
return result().getType().cast<VectorType>();
}
void getTransp(SmallVectorImpl<int64_t> &results);
+ static StringRef getTranspAttrName() { return "transp"; }
}];
let assemblyFormat = [{
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index c1437892f6f6..a6045db3d998 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -358,8 +358,23 @@ class TemplatedIndexedValue {
/// Emits a `load` when converting to a Value.
/// NOTE(review): this conversion is implicit (not `explicit`), so every place
/// this object is used as a Value emits a fresh `Load(value, indices)`.
operator Value() const { return Load(value, indices); }
+  /// Returns the base memref (the `value` being indexed) directly, without
+  /// emitting a load.
Value getBase() const { return value; }
+  /// Returns the type of the base value as a MemRefType (not the memref value
+  /// itself; use getBase() for that). The base is expected to be a memref.
+  MemRefType getMemRefType() const {
+    return value.getType().template cast<MemRefType>();
+  }
+
+  /// Returns the element type of the base memref, cast to the type class `T`
+  /// (e.g. a specific scalar or vector type).
+  template <typename T>
+  T getElementalTypeAs() const {
+    MemRefType memrefType = getMemRefType();
+    return memrefType.getElementType().template cast<T>();
+  }
+
/// Arithmetic operator overloadings.
Value operator+(Value e);
Value operator-(Value e);
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ca07ee140774..5439233c96b1 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1713,6 +1713,18 @@ static LogicalResult verify(TupleOp op) { return success(); }
// TransposeOp
//===----------------------------------------------------------------------===//
+/// Builds a vector.transpose of `vector` under the permutation `transp`.
+/// Result dimension `i` gets the size of source dimension `transp[i]`; the
+/// permutation itself is stored in the `transp` attribute as an i64 array.
+void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
+                                Value vector, ArrayRef<int64_t> transp) {
+  VectorType vt = vector.getType().cast<VectorType>();
+  ArrayRef<int64_t> shape = vt.getShape();
+  // The result shape is sized by the rank but filled by walking `transp`: a
+  // longer permutation would write past the end of `transposedShape`, and a
+  // shorter one would silently leave trailing result dims at 0.
+  assert(transp.size() == static_cast<size_t>(vt.getRank()) &&
+         "transposition length must match the vector rank");
+  SmallVector<int64_t, 4> transposedShape(vt.getRank());
+  for (unsigned i = 0, e = transp.size(); i < e; ++i) {
+    // Guard the indexed read of the source shape.
+    assert(transp[i] >= 0 && transp[i] < vt.getRank() &&
+           "transposition index out of range");
+    transposedShape[i] = shape[transp[i]];
+  }
+
+  result.addOperands(vector);
+  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
+  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
+}
+
// Eliminates transpose operations, which produce values identical to their
// input values. This happens when the dimensions of the input vector remain in
// their original order after the transpose operation.
More information about the Mlir-commits
mailing list