[Mlir-commits] [mlir] [mlir][linalg] raise generic to named ops. (PR #110421)
Renato Golin
llvmlistbot at llvm.org
Mon Oct 7 04:52:04 PDT 2024
================
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;
+ return value->get();
+}
- Block *body = genericOp.getBody();
- if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+ !isSingleYieldOp(genericOp))
return std::nullopt;
- auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
- if (!yieldOp || yieldOp.getNumOperands() != 1 ||
- yieldOp->getOperand(0) != body->getArgument(0))
+ auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+ auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+ if (!isa<MemRefType, RankedTensorType>(t0) ||
+ !isa<MemRefType, RankedTensorType>(t1))
return std::nullopt;
- return value->get();
+
+ // Check output is identity map. Injective function could also be
+ // a permutation of indices and expressible in linalg.generic but
+ // is not expressible for named broadcast op.
+ auto dstMap = genericOp.getIndexingMapsArray()[1];
+ if (!dstMap.isIdentity())
+ return std::nullopt;
+
+ SmallVector<int64_t> position;
+ auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+ // Check input map is monotonically increasing DimIds.
+ for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+ auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+ if (!expr)
+ return std::nullopt;
+ int64_t pos = expr.getPosition();
+ if (i > 0 && pos <= position[i - 1])
+ return std::nullopt;
+ position.push_back(expr.getPosition());
+ }
+
+ SmallVector<int64_t> broadcastedDims;
+ auto numDims = srcMap.getNumDims();
+ for (auto dim : llvm::seq<int64_t>(0, numDims)) {
----------------
rengolin wrote:
This is quadratic, but it should be fine, since the number of items is generally really small.
https://github.com/llvm/llvm-project/pull/110421
More information about the Mlir-commits mailing list