[Mlir-commits] [mlir] 8b525c9 - [mlir][Linalg] Add utility functions that return static loop bounds of Linalg ops

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 19 19:01:04 PST 2020


Author: MaheshRavishankar
Date: 2020-11-19T19:00:44-08:00
New Revision: 8b525c9c19f8c4cf3d7df0ec93e4935fae087e7a

URL: https://github.com/llvm/llvm-project/commit/8b525c9c19f8c4cf3d7df0ec93e4935fae087e7a
DIFF: https://github.com/llvm/llvm-project/commit/8b525c9c19f8c4cf3d7df0ec93e4935fae087e7a.diff

LOG: [mlir][Linalg] Add utility functions that return static loop bounds of Linalg ops

Differential Revision: https://reviews.llvm.org/D91749

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index d8c595fb91fd..f5669e383368 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -114,12 +114,23 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
   return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
 }
 
+/// Like `getShape`, but returns only statically-known sizes and creates no
+/// IR. Shapes of all shaped operands are concatenated in operand order; an
+/// entry is >= 0 if that dimension is static, or -1 if it is dynamic.
+SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
+
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
 /// concatenated indexing maps to the result of `getShape`. Returns None if
 /// the bounds computation fails.
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
                                               LinalgOp linalgOp);
 
+/// Returns the statically-known loop ranges of `linalgOp`, computed by
+/// applying the inverse of the concatenated indexing maps to the result of
+/// `getStaticShape`. Returns None if that inversion fails; a loop range that
+/// is not statically known is reported as -1.
+Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
+
 /// Returns the values obtained by applying `map` to the list of values.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e5f0ba013e01..c9769476baec 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -156,6 +156,15 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
   return res;
 }
 
+SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
+  SmallVector<int64_t, 8> sizes; // All operand shapes, concatenated in order.
+  for (Value operand : linalgOp.getShapedOperands()) {
+    auto dims = operand.getType().cast<ShapedType>().getShape();
+    sizes.insert(sizes.end(), dims.begin(), dims.end());
+  }
+  return sizes;
+}
+
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
                                               LinalgOp linalgOp) {
   SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
@@ -166,6 +175,15 @@ Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
   return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
 }
 
+Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
+  // Invert the concatenated indexing maps to go from shape dims to loops.
+  AffineMap loopToShapeMap = concatAffineMaps(linalgOp.getIndexingMaps());
+  AffineMap shapeToLoopMap = inversePermutation(loopToShapeMap);
+  if (shapeToLoopMap)
+    return shapeToLoopMap.compose(getStaticShape(linalgOp));
+  return {}; // Inversion failed.
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(


        


More information about the Mlir-commits mailing list