[Mlir-commits] [mlir] b1d3afc - [mlir] Factor more common utils to IndexingUtils

Hanhan Wang llvmlistbot at llvm.org
Fri Dec 2 13:27:13 PST 2022


Author: Hanhan Wang
Date: 2022-12-02T13:27:01-08:00
New Revision: b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2

URL: https://github.com/llvm/llvm-project/commit/b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2
DIFF: https://github.com/llvm/llvm-project/commit/b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2.diff

LOG: [mlir] Factor more common utils to IndexingUtils

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D139159

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Utils/IndexingUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index e231bddfcc414..28c75fcfa6530 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -70,10 +70,6 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
-/// Check if `permutation` is a permutation of the range
-/// `[0, permutation.size())`.
-bool isPermutation(ArrayRef<int64_t> permutation);
-
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index ee1d4550e953c..bc58c127e1624 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -75,6 +75,12 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
   inVec = auxVec;
 }
 
+/// Helper method to invert a permutation.
+SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
+
+/// Method to check if an interchange vector is a permutation.
+bool isPermutationVector(ArrayRef<int64_t> interchange);
+
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1f6a8fd91efb5..a6f42c9577ef7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -1392,7 +1393,7 @@ void TransposeOp::print(OpAsmPrinter &p) {
 LogicalResult TransposeOp::verify() {
   ArrayRef<int64_t> permutationRef = getPermutation();
 
-  if (!isPermutation(permutationRef))
+  if (!isPermutationVector(permutationRef))
     return emitOpError("permutation is not valid");
 
   auto inputType = getInput().getType();
@@ -1683,19 +1684,6 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
   return llvm::to_vector<4>(concatRanges);
 }
 
-bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
-  // Count the number of appearances for all indices.
-  SmallVector<int64_t> indexCounts(permutation.size(), 0);
-  for (auto index : permutation) {
-    // Exit if the index is out-of-range.
-    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
-      return false;
-    ++indexCounts[index];
-  }
-  // Return true if all indices appear once.
-  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
-}
-
 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
   if (auto memref = t.dyn_cast<MemRefType>()) {
     ss << "view";

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index bff2d5406f5c2..a83305d013bcb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -151,7 +151,7 @@ computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector) {
   if (transposeVector.empty())
     return rankedTensorType;
-  if (!isPermutation(transposeVector) ||
+  if (!isPermutationVector(transposeVector) ||
       transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
     return failure();
 

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 712d4c2de32f1..3267b2aff075f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -409,7 +410,7 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
   auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
   Type elementType = resultTensorType.getElementType();
 
-  assert(isPermutation(transposeVector) &&
+  assert(isPermutationVector(transposeVector) &&
          "expect transpose vector to be a permutation");
   assert(transposeVector.size() ==
              static_cast<size_t>(resultTensorType.getRank()) &&

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index aea0f07a2272b..092f853b11d10 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -59,34 +60,6 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
-/// Helper method to apply permutation to a vector
-template <typename T>
-static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
-                                               ArrayRef<int64_t> interchange) {
-  assert(interchange.size() == vector.size());
-  return llvm::to_vector(
-      llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
-}
-/// Helper method to apply to invert a permutation.
-static SmallVector<int64_t>
-invertPermutationVector(ArrayRef<int64_t> interchange) {
-  SmallVector<int64_t> inversion(interchange.size());
-  for (const auto &pos : llvm::enumerate(interchange)) {
-    inversion[pos.value()] = pos.index();
-  }
-  return inversion;
-}
-/// Method to check if an interchange vector is a permutation.
-static bool isPermutation(ArrayRef<int64_t> interchange) {
-  llvm::SmallDenseSet<int64_t, 4> seenVals;
-  for (auto val : interchange) {
-    if (seenVals.count(val))
-      return false;
-    seenVals.insert(val);
-  }
-  return seenVals.size() == interchange.size();
-}
-
 //===----------------------------------------------------------------------===//
 // tileUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
@@ -321,16 +294,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                                                 iterationDomain.size());
     }
     if (!interchangeVector.empty()) {
-      if (!isPermutation(interchangeVector)) {
+      if (!isPermutationVector(interchangeVector)) {
         return rewriter.notifyMatchFailure(
             op, "invalid intechange vector, not a permutation of the entire "
                 "iteration space");
       }
 
-      iterationDomain =
-          applyPermutationToVector(iterationDomain, interchangeVector);
-      tileSizeVector =
-          applyPermutationToVector(tileSizeVector, interchangeVector);
+      applyPermutationToVector(iterationDomain, interchangeVector);
+      applyPermutationToVector(tileSizeVector, interchangeVector);
     }
 
     // 3. Materialize an empty loop nest that iterates over the tiles. These
@@ -341,8 +312,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
 
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
-      offsets = applyPermutationToVector(offsets, inversePermutation);
-      sizes = applyPermutationToVector(sizes, inversePermutation);
+      applyPermutationToVector(offsets, inversePermutation);
+      applyPermutationToVector(sizes, inversePermutation);
     }
   }
 

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 1c7a89cd755f3..ed92ade8ab092 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -86,6 +86,25 @@ int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
                          std::multiplies<int64_t>());
 }
 
+llvm::SmallVector<int64_t>
+mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
+  SmallVector<int64_t> inversion(permutation.size());
+  for (const auto &pos : llvm::enumerate(permutation)) {
+    inversion[pos.value()] = pos.index();
+  }
+  return inversion;
+}
+
+bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
+  llvm::SmallDenseSet<int64_t, 4> seenVals;
+  for (auto val : interchange) {
+    if (seenVals.count(val))
+      return false;
+    seenVals.insert(val);
+  }
+  return seenVals.size() == interchange.size();
+}
+
 llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                                 unsigned dropFront,
                                                 unsigned dropBack) {


        


More information about the Mlir-commits mailing list