[Mlir-commits] [mlir] 9072f1b - [mlir][linalg] Add isPermutation helper (NFC).

Tobias Gysi llvmlistbot at llvm.org
Tue Sep 21 08:07:57 PDT 2021


Author: Tobias Gysi
Date: 2021-09-21T15:07:39Z
New Revision: 9072f1b5f81347b36f0668e8cc10802fedbc6cfd

URL: https://github.com/llvm/llvm-project/commit/9072f1b5f81347b36f0668e8cc10802fedbc6cfd
DIFF: https://github.com/llvm/llvm-project/commit/9072f1b5f81347b36f0668e8cc10802fedbc6cfd.diff

LOG: [mlir][linalg] Add isPermutation helper (NFC).

Add a helper function, isPermutation, that checks whether an index vector is a permutation of the range [0, size). Additionally, refactor applyPermutationToVector to take the permutation as ArrayRef<int64_t> instead of ArrayRef<unsigned>.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110135

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index a8d4edf6c072d..67e2fc35367af 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -29,16 +29,20 @@ class LinalgDependenceGraph;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Check if `permutation` is a permutation of the range
+/// `[0, permutation.size())`.
+bool isPermutation(ArrayRef<int64_t> permutation);
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
 /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
-                              ArrayRef<unsigned> permutation) {
+                              ArrayRef<int64_t> permutation) {
   SmallVector<T, N> auxVec(inVec.size());
-  for (unsigned i = 0; i < permutation.size(); ++i)
-    auxVec[i] = inVec[permutation[i]];
+  for (auto en : enumerate(permutation))
+    auxVec[en.index()] = inVec[en.value()];
   inVec = auxVec;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 0e65cb23964ca..13a6fe4fcc903 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -367,6 +367,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
                                            ArrayRef<int64_t> tileInterchange) {
   assert(tileSizes.size() == tileInterchange.size() &&
          "expect the number of tile sizes and interchange dims to match");
+  assert(isPermutation(tileInterchange) &&
+         "expect tile interchange is a permutation");
 
   // Create an empty tile loop nest.
   TileLoopNest tileLoopNest(consumerOp);
@@ -375,9 +377,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
   // inner reduction dimensions.
   SmallVector<StringAttr> iterTypes =
       llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
-  applyPermutationToVector(
-      iterTypes,
-      SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
+  applyPermutationToVector(iterTypes, tileInterchange);
   auto *it = find_if(iterTypes, [&](StringAttr iterType) {
     return !isParallelIterator(iterType);
   });
@@ -459,14 +459,10 @@ struct LinalgTileAndFuseTensorOps
                                    tileInterchange.begin() +
                                        rootOp.getNumLoops());
 
-    // As a tiling can only tile a loop dimension once, `rootInterchange` has to
-    // be a permutation of the `rootOp` loop dimensions.
-    SmallVector<AffineExpr> rootInterchangeExprs;
-    transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
-              [&](int64_t dim) { return b.getAffineDimExpr(dim); });
-    AffineMap rootInterchangeMap = AffineMap::get(
-        rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
-    if (!rootInterchangeMap.isPermutation())
+    // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+    // It has to be a permutation since the tiling cannot tile the same loop
+    // dimension multiple times.
+    if (!isPermutation(rootInterchange))
       return notifyFailure(
           "expect the tile interchange permutes the root loops");
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index d94fa30500b75..a42c204a389ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -69,7 +69,9 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
   ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
   SmallVector<Attribute, 4> itTypesVector;
   llvm::append_range(itTypesVector, itTypes);
-  applyPermutationToVector(itTypesVector, interchangeVector);
+  SmallVector<int64_t> permutation(interchangeVector.begin(),
+                                   interchangeVector.end());
+  applyPermutationToVector(itTypesVector, permutation);
   genericOp->setAttr(getIteratorTypesAttrName(),
                      ArrayAttr::get(context, itTypesVector));
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f3b9516bacf00..e36915d60108f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -206,8 +206,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
     invPermutationMap = inversePermutation(
         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
     assert(invPermutationMap);
-    applyPermutationToVector(loopRanges, interchangeVector);
-    applyPermutationToVector(iteratorTypes, interchangeVector);
+    SmallVector<int64_t> permutation(interchangeVector.begin(),
+                                     interchangeVector.end());
+    applyPermutationToVector(loopRanges, permutation);
+    applyPermutationToVector(iteratorTypes, permutation);
   }
 
   // 2. Create the tiled loops.

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e0b1a4840b41f..b7a2becefa77f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -138,6 +138,19 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
 namespace mlir {
 namespace linalg {
 
+bool isPermutation(ArrayRef<int64_t> permutation) {
+  // Count the number of appearances for all indices.
+  SmallVector<int64_t> indexCounts(permutation.size(), 0);
+  for (auto index : permutation) {
+    // Exit if the index is out-of-range.
+    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
+      return false;
+    indexCounts[index]++;
+  }
+  // Return true if all indices appear once.
+  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
+}
+
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {


        


More information about the Mlir-commits mailing list