[Mlir-commits] [mlir] a7b2977 - [mlir][Linalg] Add Utility method to get loop ranges for a LinalgOp.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 9 22:58:07 PDT 2020


Author: MaheshRavishankar
Date: 2020-09-09T22:55:39-07:00
New Revision: a7b2977aa613b5e9b9d9e6e8232f89012404c52c

URL: https://github.com/llvm/llvm-project/commit/a7b2977aa613b5e9b9d9e6e8232f89012404c52c
DIFF: https://github.com/llvm/llvm-project/commit/a7b2977aa613b5e9b9d9e6e8232f89012404c52c.diff

LOG: [mlir][Linalg] Add Utility method to get loop ranges for a LinalgOp.

Also refactor the getViewSizes method to take a LinalgOp instead of
being templated on the concrete op type. The templated version is kept,
as a thin wrapper, for compatibility.

Differential Revision: https://reviews.llvm.org/D87303

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index beef1a70096e6..c0c59bda1894f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -94,42 +94,22 @@ Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
                          unsigned consumerIdx,
                          OperationFolder *folder = nullptr);
 
-/// Returns the linearized list of all view dimensions in a linalgOp. Applying
+/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying
 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
 /// derivation of loop ranges for any linalgOp.
-template <typename ConcreteOp>
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
-  auto loc = linalgOp.getLoc();
-  SmallVector<Value, 8> res;
-  SmallVector<unsigned, 4> ranks;
-  for (auto v : linalgOp.getInputsAndOutputBuffers()) {
-    MemRefType t = v.getType().template cast<MemRefType>();
-    ranks.push_back(t.getRank());
-    for (unsigned i = 0; i < t.getRank(); ++i)
-      res.push_back(builder.create<DimOp>(loc, v, i));
-  }
-
-  auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
-  if (attr) {
-    // Find the correct position for inserting values for symbols.
-    unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
-    for (unsigned idx = 0; idx < attr.getInt(); idx++)
-      symbolsPos += ranks[idx];
-
-    // Append the end of the value list that corresponds to the
-    // values mapping to symbols. Since inside concatinated map symbols are
-    // repeated we have to repeat the sizes as well.
-
-    // Reserve is mandatory to avoid a potential undefined behavior with
-    // pushing back to smallvector from itself.
-    res.reserve(res.size() + ranks.size() * numSymb);
-    for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
-      for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
-        res.push_back(res[symbolsPos + idx2]);
-  }
-  return res;
+SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp);
+template <typename ConcreteOpTy> // Compatibility overload for concrete ops.
+SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
+  return getViewSizes(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
 }
 
+/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
+/// concatenated indexing maps to the result of `getViewSizes`. Returns None if
+/// the bounds computation fails.
+Optional<SmallVector<Value, 4>>
+getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
+              OperationFolder *folder = nullptr);
+
 /// Returns the values obtained by applying `map` to the list of values.
 /// When non-null, the optional pointer `folder` is used to call into the
 /// `createAndFold` builder method. If `folder` is null, the regular `create`

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index cf14555aa63fc..585b00189964d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -147,6 +147,50 @@ static void unpackRanges(ArrayRef<SubViewOp::Range> ranges,
 namespace mlir {
 namespace linalg {
 
+/// Return a DimOp per dimension of each input/output buffer of linalgOp.
+SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
+  auto loc = linalgOp.getLoc();
+  SmallVector<Value, 8> res;
+  SmallVector<unsigned, 4> ranks;
+  for (auto v : linalgOp.getInputsAndOutputBuffers()) {
+    MemRefType t = v.getType().template cast<MemRefType>();
+    ranks.push_back(t.getRank());
+    for (unsigned i = 0; i < t.getRank(); ++i)
+      res.push_back(builder.create<DimOp>(loc, v, i));
+  }
+
+  auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
+  if (attr) {
+    // Locate where the sizes of the `symbol_source` operand start in `res`.
+    unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
+    for (unsigned idx = 0; idx < attr.getInt(); idx++)
+      symbolsPos += ranks[idx];
+
+    // Symbols repeat in the concatenated indexing map, so append the sizes
+    // of the symbol_source operand once per operand. NOTE(review): the
+    // symbol_source index is not range-checked against `ranks`.
+
+    // Reserve first: push_back(res[i]) reads an element of `res` itself, and
+    // a reallocation inside push_back could invalidate that reference.
+    res.reserve(res.size() + ranks.size() * numSymb);
+    for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
+      for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
+        res.push_back(res[symbolsPos + idx2]);
+  }
+  return res;
+}
+
+Optional<SmallVector<Value, 4>>
+getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
+  SmallVector<Value, 8> viewSizes = getViewSizes(builder, linalgOp);
+  AffineMap invertedMap =
+      inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
+  if (!invertedMap)
+    return {}; // Inversion failed: report no loop ranges (None).
+  return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
+                          folder);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(


        


More information about the Mlir-commits mailing list