[Mlir-commits] [mlir] 941005f - [mlir] NFC - Add a builder to vector.transpose

Nicolas Vasilache llvmlistbot at llvm.org
Thu May 21 02:23:40 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-21T05:18:58-04:00
New Revision: 941005f51ac4a53ea6cc92dfdf06069c90c02f55

URL: https://github.com/llvm/llvm-project/commit/941005f51ac4a53ea6cc92dfdf06069c90c02f55
DIFF: https://github.com/llvm/llvm-project/commit/941005f51ac4a53ea6cc92dfdf06069c90c02f55.diff

LOG: [mlir] NFC - Add a builder to vector.transpose

Summary: Also expose the vector.outerproduct and vector.transpose ops to EDSCs, and add memref type accessors to TemplatedIndexedValue.

Differential Revision: https://reviews.llvm.org/D80333

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/EDSC/Builders.h
    mlir/lib/Dialect/Vector/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
index 6b5c4be7b2f4..3a3551ddc3eb 100644
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
@@ -20,9 +20,9 @@ using vector_insert = ValueBuilder<vector::InsertOp>;
 using vector_fma = ValueBuilder<vector::FMAOp>;
 using vector_extract = ValueBuilder<vector::ExtractOp>;
 using vector_matmul = ValueBuilder<vector::MatmulOp>;
+using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
 using vector_print = OperationBuilder<vector::PrintOp>;
 using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
 using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
+using vector_transpose = ValueBuilder<vector::TransposeOp>;
 using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
-using vector_insert = ValueBuilder<vector::InsertOp>;
-using vector_fma = ValueBuilder<vector::FMAOp>;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 575b99d51c97..264c8ad034c8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1385,6 +1385,12 @@ def Vector_TransposeOp :
                           [c, f] ]
     ```
   }];
+  // Builds a transpose of `vector` by the permutation `transp`. The result
+  // type is inferred: result dimension `i` takes the size of source dimension
+  // `transp[i]`, and the element type is unchanged.
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, Value vector, "
+    "ArrayRef<int64_t> transp">];
   let extraClassDeclaration = [{
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
@@ -1393,6 +1399,8 @@ def Vector_TransposeOp :
       return result().getType().cast<VectorType>();
     }
     void getTransp(SmallVectorImpl<int64_t> &results);
+    /// Returns the name of the attribute that stores the permutation.
+    static StringRef getTranspAttrName() { return "transp"; }
   }];
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)

diff  --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index c1437892f6f6..a6045db3d998 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -358,8 +358,20 @@ class TemplatedIndexedValue {
   /// Emits a `load` when converting to a Value.
   operator Value() const { return Load(value, indices); }
 
+  /// Returns the base memref.
   Value getBase() const { return value; }
 
+  /// Returns the type of the underlying memref.
+  MemRefType getMemRefType() const {
+    return value.getType().template cast<MemRefType>();
+  }
+
+  /// Returns the element type of the underlying memref, cast to `T`.
+  template <typename T>
+  T getElementalTypeAs() const {
+    return getMemRefType().getElementType().template cast<T>();
+  }
+
   /// Arithmetic operator overloadings.
   Value operator+(Value e);
   Value operator-(Value e);

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ca07ee140774..5439233c96b1 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1713,6 +1713,28 @@ static LogicalResult verify(TupleOp op) { return success(); }
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a TransposeOp permuting the dimensions of `vector` by `transp`.
+/// The result type is inferred: result dimension `i` has the size of source
+/// dimension `transp[i]`, and the element type is unchanged. `transp` is
+/// expected to be a permutation of [0, rank); its length and range are asserted.
+void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
+                                Value vector, ArrayRef<int64_t> transp) {
+  VectorType vt = vector.getType().cast<VectorType>();
+  int64_t rank = vt.getRank();
+  assert(static_cast<int64_t>(transp.size()) == rank &&
+         "transposition must have one entry per vector dimension");
+  SmallVector<int64_t, 4> transposedShape(rank);
+  for (unsigned i = 0, e = transp.size(); i < e; ++i) {
+    assert(transp[i] >= 0 && transp[i] < rank &&
+           "transposition index out of range");
+    transposedShape[i] = vt.getShape()[transp[i]];
+  }
+
+  result.addOperands(vector);
+  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
+  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
+}
+
 // Eliminates transpose operations, which produce values identical to their
 // input values. This happens when the dimensions of the input vector remain in
 // their original order after the transpose operation.


        


More information about the Mlir-commits mailing list