[Mlir-commits] [mlir] [mlir][linalg] Genericize MapOp (PR #162742)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 10 07:28:15 PDT 2025


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/162742

>From 7d4273a64897f54f67a3b55f086fe6613dba659e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 9 Oct 2025 16:41:03 -0500
Subject: [PATCH 1/2] [mlir][linalg] Genericize MapOp

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  4 ----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 21 +++++++++---------
 .../Linalg/Transforms/Generalization.cpp      |  6 ++---
 .../Dialect/Linalg/generalize-named-ops.mlir  | 22 ++++++++++++-------
 .../lower-to-loops-using-interface.mlir       |  6 ++---
 5 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
-    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
-      return getDpsInputOperands();
-    }
-
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
       return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..6134a7535808d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
                                      OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
     setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
 }
 
 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
 
   if (bodyBuild)
     buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, /*outputs=*/{}, bodyBuild);
+                       inputs, /*outputs=*/{init}, bodyBuild);
 }
 
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
                                  ArrayRef<Value> operands,
-                                 bool initFirst = false) {
+                                 bool initFirst = false, bool mapInit = true) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
   // If initFirst flag is enabled, we consider init as the first position of
   // payload operands.
   if (initFirst) {
-    payloadOpOperands.push_back(block.getArguments().back());
+    if (mapInit)
+      payloadOpOperands.push_back(block.getArguments().back());
     for (const auto &arg : block.getArguments().drop_back())
       payloadOpOperands.push_back(arg);
   } else {
     payloadOpOperands = {block.getArguments().begin(),
-                         block.getArguments().end()};
+                         block.getArguments().end() - int(!mapInit)};
   }
 
   Operation *payloadOp = b.create(
@@ -1551,12 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (payloadOpName.has_value()) {
-    if (!result.operands.empty())
-      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
-                           payloadOpAttrs,
-                           ArrayRef(result.operands).drop_back());
-    else
-      result.addRegion();
+    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+                         ArrayRef(result.operands), false, false);
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1659,7 +1658,7 @@ LogicalResult MapOp::verify() {
   auto blockArgs = bodyBlock->getArguments();
 
   // Checks if the number of `inputs` match the arity of the `mapper` region.
-  if (getInputs().size() != blockArgs.size())
+  if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
                          << getInputs().size() << " and " << blockArgs.size();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
-  // trivially generalize a `linalg.map`, as it does not use the output as
-  // region arguments in the block.
-  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+  // Bailout if `linalgOp` is already a generic.
+  if (isa<GenericOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
 
 // -----
 
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
-  // CHECK: linalg.map
-  // CHECK-NOT: linalg.generic
-  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
-    () {
-      linalg.yield %cst : f32
-    }
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
   return
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
 // -----
 
 func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @map(%lhs: memref<64xf32>,
-    %rhs: memref<64xf32>, %out: memref<64xf32>) {
+    %rhs: memref<64xf32>, %init: memref<64xf32>) {
   linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
-             outs(%out : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+             outs(%init : memref<64xf32>)
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }

>From 7a1d5f3cfdb7714e425da84e97a0855982b34831 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 10 Oct 2025 09:27:56 -0500
Subject: [PATCH 2/2] revert unnecessary change

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6134a7535808d..4a3b7ad49eb59 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1554,8 +1554,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (payloadOpName.has_value()) {
-    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         ArrayRef(result.operands), false, false);
+    if (!result.operands.empty())
+      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
+                           payloadOpAttrs, ArrayRef(result.operands), false,
+                           false);
+    else
+      result.addRegion();
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1657,11 +1661,13 @@ LogicalResult MapOp::verify() {
   auto *bodyBlock = getBody();
   auto blockArgs = bodyBlock->getArguments();
 
-  // Checks if the number of `inputs` match the arity of the `mapper` region.
+  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+  // region.
   if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
-                         << getInputs().size() << " and " << blockArgs.size();
+                         << getInputs().size() + 1 << " and "
+                         << blockArgs.size();
 
   // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :



More information about the Mlir-commits mailing list