[Mlir-commits] [mlir] [mlir] Add inferContractionDims util for indexing map inputs (PR #76081)

Quinn Dawkins llvmlistbot at llvm.org
Thu Dec 21 07:35:19 PST 2023


================
@@ -201,11 +202,47 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
   return res;
 }
 
+/// Overload of findPermutationsIndexingOperand for an operand of `linalgOp`:
+/// fetches the operand's indexing map and the op's iterator types, then
+/// forwards them (together with `iter`) to findPermutationsIndexingOperandImpl.
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+                                utils::IteratorType iter) {
+  // The operand must belong to the op whose indexing maps we query.
+  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+  AffineMap operandMap = linalgOp.getMatchingIndexingMap(opOperand);
+  auto loopIteratorTypes = linalgOp.getIteratorTypesArray();
+  return findPermutationsIndexingOperandImpl(operandMap, loopIteratorTypes,
+                                             iter);
+}
+
+/// Overload of findPermutationsIndexingOperand that works directly on an
+/// indexing map and the matching loop iterator types, so it can be used
+/// without a LinalgOp. `iterators` must hold one entry per dimension of
+/// `indexingMap`; the work is delegated to findPermutationsIndexingOperandImpl.
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(AffineMap indexingMap,
+                                ArrayRef<utils::IteratorType> iterators,
+                                utils::IteratorType iter) {
+  // Each loop dimension of the map needs exactly one iterator type; carry a
+  // message like the LinalgOp overload does so a failure is self-explanatory.
+  assert(iterators.size() == indexingMap.getNumDims() &&
+         "expected one iterator type per indexing map dimension");
+  return findPermutationsIndexingOperandImpl(indexingMap, iterators, iter);
+}
+
 namespace {
 // Shorthands for the parallel and reduction iterator types. Declared
 // constexpr: they are named constants and must never be reassigned.
 constexpr auto par = utils::IteratorType::parallel;
 constexpr auto red = utils::IteratorType::reduction;
 } // namespace
 
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static FailureOr<ArrayRef<utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  if (!map.isProjectedPermutation())
+    return failure();
+  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+  for (auto expr : map.getResults()) {
+    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+      iterators[dim.getPosition()] = par;
+    }
+  }
+  if (iterators.size() != map.getNumDims())
+    return failure();
----------------
qedawkins wrote:

nit: This check can never fail: the vector is constructed with `map.getNumDims()` elements and its size is never changed afterwards.

https://github.com/llvm/llvm-project/pull/76081


More information about the Mlir-commits mailing list