[Mlir-commits] [mlir] [mlir][linalg] raise generic to named ops. (PR #110421)
Javed Absar
llvmlistbot at llvm.org
Thu Oct 10 04:45:12 PDT 2024
================
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;
+ return value->get();
+}
- Block *body = genericOp.getBody();
- if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+ !isSingleYieldOp(genericOp))
return std::nullopt;
- auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
- if (!yieldOp || yieldOp.getNumOperands() != 1 ||
- yieldOp->getOperand(0) != body->getArgument(0))
+ auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+ auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+ if (!isa<MemRefType, RankedTensorType>(t0) ||
+ !isa<MemRefType, RankedTensorType>(t1))
return std::nullopt;
- return value->get();
+
+ // Check output is identity map. Injective function could also be
+ // a permutation of indices and expressible in linalg.generic but
+ // is not expressible for named broadcast op.
+ auto dstMap = genericOp.getIndexingMapsArray()[1];
+ if (!dstMap.isIdentity())
+ return std::nullopt;
+
+ SmallVector<int64_t> position;
+ auto srcMap = genericOp.getIndexingMapsArray()[0];
----------------
javedabsar1 wrote:
done.
https://github.com/llvm/llvm-project/pull/110421
More information about the Mlir-commits mailing list