[Mlir-commits] [mlir] 0b781db - [mlir] Add new builders to linalg.reshape.

Alexander Belyaev llvmlistbot at llvm.org
Thu Jun 11 03:48:12 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-11T12:47:35+02:00
New Revision: 0b781db9087977b758c27bb02be1f80cd00bf0d7

URL: https://github.com/llvm/llvm-project/commit/0b781db9087977b758c27bb02be1f80cd00bf0d7
DIFF: https://github.com/llvm/llvm-project/commit/0b781db9087977b758c27bb02be1f80cd00bf0d7.diff

LOG: [mlir] Add new builders to linalg.reshape.

Differential Revision: https://reviews.llvm.org/D81640

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index ce3a1cef3b35..fbb3f062a494 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -34,6 +34,12 @@ class PoolingMaxOp;
 class PoolingMinOp;
 class PoolingSumOp;
 
+/// Dimension positions that one reassociation map of a reshape groups
+/// together, e.g. {0, 1} for the map (d0, d1, d2) -> (d0, d1).
+using ReassociationIndicies = SmallVector<int64_t, 2>;
+/// The same grouping expressed as affine dimension expressions.
+using ReassociationExprs = SmallVector<AffineExpr, 2>;
+
 /// Returns the name mangled library call name to disambiguate between different
 /// overloads at the C level. The name mangling scheme is basic and uses MLIR
 /// type names:

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1615957ff0c3..64f92ee5d406 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -64,16 +64,38 @@ def Linalg_RangeOp :
 class Linalg_ReshapeLikeOp<string mnemonic> :
     Linalg_Op<mnemonic, [NoSideEffect]> {
   let builders = [
-    // Builder for a contracting reshape whose result type is computed from
+    // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
     OpBuilder<"OpBuilder &b, OperationState &result, Value src, "
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">,
-    // Builder for a reshape whose result type is passed explicitly. This may be
-    // either a contracting or expanding reshape.
-    OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, Value src,"
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">];
+              "ArrayRef<ReassociationExprs> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">,
+    // Same as above, but each reassociation group is given as a list of
+    // dimension indices, which are converted to affine dim expressions.
+    OpBuilder<"OpBuilder &b, OperationState &result, Value src, "
+              "ArrayRef<ReassociationIndicies> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}", [{
+      auto reassociationMaps =
+          convertReassociationIndiciesToMaps(b, reassociation);
+      build(b, result, src, reassociationMaps, attrs);
+    }]>,
+
+    // Builders for a reshape whose result type is passed explicitly. This may
+    // be either a contracting or expanding reshape.
+    OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, "
+              "Value src, ArrayRef<ReassociationExprs> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">,
+    // Same as above, but each reassociation group is given as a list of
+    // dimension indices, which are converted to affine dim expressions.
+    OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, "
+              "Value src, ArrayRef<ReassociationIndicies> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}", [{
+      auto reassociationMaps =
+          convertReassociationIndiciesToMaps(b, reassociation);
+      // Forward `resultType`: the builder without it always infers the
+      // collapsed type, which would silently break expanding reshapes.
+      build(b, result, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index db4587fce014..717baa02de61 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -476,9 +476,9 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
 }
 
 template <typename AffineExprTy>
-unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
   unsigned pos = 0;
-  for (auto exprs : exprArrays) {
+  for (const auto &exprs : exprArrays) { // elements are now SmallVectors; avoid copying each one
     for (auto expr : exprs) {
       expr.walk([&pos](AffineExpr e) {
         if (auto d = e.dyn_cast<AffineExprTy>())
@@ -490,23 +490,37 @@ unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
 }
 
 static SmallVector<AffineMap, 4>
-getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
          "Expected symbol-less expressions");
   SmallVector<AffineMap, 4> maps;
   maps.reserve(reassociation.size());
-  for (auto exprs : reassociation) {
-    assert(exprs.size() != 0);
+  for (const auto &exprs : reassociation) {
+    assert(!exprs.empty()); // exprs[0] is read below to obtain the context
     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
   }
   return maps;
 }
 
-void mlir::linalg::ReshapeOp::build(
-    OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ArrayRef<AffineExpr>> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
+static SmallVector<SmallVector<AffineExpr, 2>, 2>
+convertReassociationIndiciesToMaps( // despite the name, returns affine exprs per group, not AffineMaps
+    OpBuilder &b, ArrayRef<ReassociationIndicies> reassociationIndicies) {
+  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
+  for (const auto &indicies : reassociationIndicies) {
+    SmallVector<AffineExpr, 2> reassociationMap;
+    reassociationMap.reserve(indicies.size());
+    for (int64_t index : indicies)
+      reassociationMap.push_back(b.getAffineDimExpr(index)); // index i becomes d_i
+    reassociationMaps.push_back(std::move(reassociationMap));
+  }
+  return reassociationMaps;
+}
+
+void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
+                                    Value src, // result type is inferred from src below
+                                    ArrayRef<ReassociationExprs> reassociation,
+                                    ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
   auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
@@ -515,10 +529,10 @@ void mlir::linalg::ReshapeOp::build(
                       b.getAffineMapArrayAttr(maps));
 }
 
-void mlir::linalg::ReshapeOp::build(
-    OpBuilder &b, OperationState &result, Type resultType, Value src,
-    ArrayRef<ArrayRef<AffineExpr>> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
+void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
+                                    Type resultType, Value src, // resultType is used as given, not recomputed
+                                    ArrayRef<ReassociationExprs> reassociation,
+                                    ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
   build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
@@ -622,7 +636,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
 
 void mlir::linalg::TensorReshapeOp::build(
     OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<ReassociationExprs> reassociation, // result type is inferred from src below
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
   auto resultType = computeTensorReshapeCollapsedType(
@@ -634,7 +648,7 @@ void mlir::linalg::TensorReshapeOp::build(
 
 void mlir::linalg::TensorReshapeOp::build(
     OpBuilder &b, OperationState &result, Type resultType, Value src,
-    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<ReassociationExprs> reassociation, // resultType is used as given, not recomputed
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
   build(b, result, resultType, src, attrs);

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 4b01f7110532..3435926f867e 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -941,6 +941,8 @@ TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
 //       CHECK: linalg.reshape {{.*}} [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : memref<32x16xf32> into memref<4x8x16xf32>
 // clang-format on
 TEST_FUNC(linalg_metadata_ops) {
+  using linalg::ReassociationExprs; // NOTE(review): the ReassociationIndicies builders are not exercised here
+
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
   auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
@@ -950,9 +952,10 @@ TEST_FUNC(linalg_metadata_ops) {
   AffineExpr i, j, k;
   bindDims(&globalContext(), i, j, k);
   Value v(f.getArgument(0));
-  auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
-  linalg_reshape(memrefType, reshaped,
-                 ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+  SmallVector<ReassociationExprs, 2> maps = {ReassociationExprs({i, j}), // groups {d0, d1} and {d2}
+                                             ReassociationExprs({k})};
+  auto reshaped = linalg_reshape(v, maps); // inferred type: collapses to memref<32x16xf32>
+  linalg_reshape(memrefType, reshaped, maps); // explicit type: expands back to memref<4x8x16xf32>
 
   f.print(llvm::outs());
   f.erase();


        


More information about the Mlir-commits mailing list