[Mlir-commits] [mlir] bf3f701 - [mlir][NFC] Generalize `getPermutedPosition`

Diego Caballero llvmlistbot at llvm.org
Thu Dec 1 11:00:48 PST 2022


Author: Diego Caballero
Date: 2022-12-01T18:58:25Z
New Revision: bf3f7016b17970478d1b8af481318c62d0a9004e

URL: https://github.com/llvm/llvm-project/commit/bf3f7016b17970478d1b8af481318c62d0a9004e
DIFF: https://github.com/llvm/llvm-project/commit/bf3f7016b17970478d1b8af481318c62d0a9004e.diff

LOG: [mlir][NFC] Generalize `getPermutedPosition`

Small change to generalize the `getPermutedPosition` utility so that it also
supports projected permutations. The utility is renamed to `getResultPosition`.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138946

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/IR/AffineMap.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index ccd167dfa5ab..74cc3b1eb17a 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -166,9 +166,10 @@ class AffineMap {
   /// when the caller knows it is safe to do so.
   unsigned getDimPosition(unsigned idx) const;
 
-  /// Extracts the permuted position where given input index resides.
-  /// Fails when called on a non-permutation.
-  unsigned getPermutedPosition(unsigned input) const;
+  /// Extracts the first result position where `input` dimension resides.
+  /// Returns `llvm::None` if `input` is not a dimension expression or cannot be
+  /// found in results.
+  Optional<unsigned> getResultPosition(AffineExpr input) const;
 
   /// Return true if any affine expression involves AffineDimExpr `position`.
   bool isFunctionOfDim(unsigned position) const {

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f4649b3876cc..652e0504c5ee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -300,7 +300,10 @@ uint64_t mlir::sparse_tensor::toStoredDim(const SparseTensorEncodingAttr &enc,
     auto order = enc.getDimOrdering();
     if (order) {
       assert(order.isPermutation());
-      return order.getPermutedPosition(d);
+      auto maybePos =
+          order.getResultPosition(getAffineDimExpr(d, enc.getContext()));
+      assert(maybePos.has_value());
+      return *maybePos;
     }
   }
   return d;

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index d275fc88cc90..e6186834cf5e 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -328,12 +328,16 @@ unsigned AffineMap::getDimPosition(unsigned idx) const {
   return getResult(idx).cast<AffineDimExpr>().getPosition();
 }
 
-unsigned AffineMap::getPermutedPosition(unsigned input) const {
-  assert(isPermutation() && "invalid permutation request");
-  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
-    if (getDimPosition(i) == input)
+Optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
+  if (!input.isa<AffineDimExpr>())
+    return llvm::None;
+
+  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) {
+    if (getResult(i) == input)
       return i;
-  llvm_unreachable("incorrect permutation request");
+  }
+
+  return llvm::None;
 }
 
 /// Folds the results of the application of an affine map on the provided


        


More information about the Mlir-commits mailing list