[Mlir-commits] [mlir] [mlir][linalg] raise generic to named ops. (PR #110421)

Javed Absar llvmlistbot at llvm.org
Thu Oct 10 04:45:11 PDT 2024


================
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
     return std::nullopt;
+  return value->get();
+}
 
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;
 
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(t0) ||
+      !isa<MemRefType, RankedTensorType>(t1))
     return std::nullopt;
-  return value->get();
+
+  // Check that the output map is the identity. An injective map could
+  // also be a permutation of indices, expressible in linalg.generic,
+  // but not expressible with the named broadcast op.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+  // Check that the input map is strictly increasing DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(pos);
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // Mapping checks.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+      !mapRange.front().isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation;
+  auto map = mapRange.front();
----------------
javedabsar1 wrote:

I spotted an error and fixed it. But now it's harder to use range-for, as the index value is needed.
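
For reference, a minimal sketch (not part of the patch) of how the loop could keep a range-for by pairing each element with its index via llvm::enumerate (llvm/ADT/STLExtras.h); names mirror the hunk above:

  SmallVector<int64_t> position;
  for (auto [idx, result] : llvm::enumerate(srcMap.getResults())) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(result);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    // Reject dims that are not strictly increasing; idx replaces `i`.
    if (idx > 0 && pos <= position[idx - 1])
      return std::nullopt;
    position.push_back(pos);
  }

Alternatively, carrying a single prevPos value across iterations would avoid indexing into position altogether.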

https://github.com/llvm/llvm-project/pull/110421

