[Mlir-commits] [mlir] [mlir] Add helper to check elementwise-mappable ops with tensors and scalars (PR #154872)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Aug 25 03:12:20 PDT 2025
================
@@ -81,13 +105,39 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
- SmallVector<AffineMap, 3> indexingMaps(
- op->getNumResults() + op->getNumOperands(),
- rewriter.getMultiDimIdentityMap(rank));
- SmallVector<utils::IteratorType, 6> iteratorTypes(
+ auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+ auto rank = resTy.getRank();
+
+ // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0.
+ AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+ /*results=*/{}, rewriter.getContext());
+ AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+ // Create indexing maps: one per operand, one per result.
+ SmallVector<AffineMap, 6> indexingMaps;
+ indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+ for (Value v : op->getOperands()) {
+ Type ty = v.getType();
----------------
adam-smnk wrote:
Looks like you could iterate over `op->getOperandTypes()` directly instead of reading the type off each operand value.
nit: I'd suggest a more expressive name than `v` (even just `val` would be better)
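
For illustration, a minimal standalone sketch of the loop shape this suggests. `Type`, `Operation`, `MapKind`, and `selectOperandMaps` are hypothetical stand-ins so the snippet compiles without MLIR; only the `getOperandTypes()` name and the identity-vs-scalar rule come from the patch.

```cpp
#include <vector>

// Hypothetical stand-in for mlir::Type; not the real class.
struct Type {
  bool isRankedTensor;
  int rank; // meaningful only when isRankedTensor is true
};

// Hypothetical stand-in for mlir::Operation; only models operand types.
struct Operation {
  std::vector<Type> operandTypes;
  // Same role as Operation::getOperandTypes(): yields the types directly.
  const std::vector<Type> &getOperandTypes() const { return operandTypes; }
};

// Placeholder for the two AffineMap choices in the patch.
enum class MapKind { Identity, Scalar };

// One map per operand: identity for ranked tensors with rank > 0,
// scalar map for scalars and rank-0 tensors.
std::vector<MapKind> selectOperandMaps(const Operation &op) {
  std::vector<MapKind> maps;
  maps.reserve(op.getOperandTypes().size());
  for (const Type &operandTy : op.getOperandTypes()) {
    bool isTensor = operandTy.isRankedTensor && operandTy.rank > 0;
    maps.push_back(isTensor ? MapKind::Identity : MapKind::Scalar);
  }
  return maps;
}
```

Iterating over the types removes the intermediate `Value` variable, which also takes care of the naming nit.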
https://github.com/llvm/llvm-project/pull/154872