[Mlir-commits] [mlir] [mlir][linalg] Fix incorrect linalg short form printing (PR #153219)

Boyana Norris llvmlistbot at llvm.org
Wed Aug 13 21:08:44 PDT 2025


https://github.com/brnorris03 updated https://github.com/llvm/llvm-project/pull/153219

>From c24793e4b707d74b727887849e4daad45d513322 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 07:52:12 -0700
Subject: [PATCH 1/4] add yield check

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 58 +++++++++++++++++++-----
 1 file changed, 46 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..1e2e6893eac21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,36 +1570,70 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+// Check if a block contains a single payload operation that can be printed in
+// short form. The block must contain exactly 2 operations: the payload op and a
+// yield.
+static bool canUseShortForm(Block *body) {
   if (body->getOperations().size() != 2)
-    return nullptr;
+    return false;
+
   Operation &payload = body->getOperations().front();
   assert(isa<YieldOp>(body->getOperations().back()));
 
-  if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
+  // Check that the yield has exactly one operand that comes from the payload
+  auto yieldOp = cast<YieldOp>(body->getOperations().back());
+  if (yieldOp.getNumOperands() != 1)
+    return false;
+
+  Value yieldOperand = yieldOp.getOperand(0);
+  if (!yieldOperand.getDefiningOp() || yieldOperand.getDefiningOp() != &payload)
+    return false;
+
+  return true;
+}
+
+// Find a payload operation that can be printed in short form.
+// For MapOp (initFirst=false): operands must match block arguments in order.
+// For ReduceOp (initFirst=true): init operand must be first, then operands must
+// match block arguments.
+static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+  if (!canUseShortForm(body))
     return nullptr;
+
+  Operation &payload = body->getOperations().front();
+
   if (initFirst) {
-    // check init
-    if (payload.getOperands().back() != body->getArgument(0))
+    // For ReduceOp: check that operand count matches block argument count + 1
+    // (for init)
+    if (payload.getNumOperands() == 0 ||
+        payload.getNumOperands() != body->getNumArguments() + 1)
+      return nullptr;
+
+    // Check that init operand is first
+    if (payload.getOperands().front() != body->getArgument(0))
       return nullptr;
-    // check rest
+
+    // Check that remaining operands match block arguments in order
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
+         llvm::zip(payload.getOperands().drop_front(),
+                   body->getArguments().drop_front())) {
       if (bbArg != operand)
         return nullptr;
     }
   } else {
+    // For MapOp: check that operand count matches block argument count
+    if (payload.getNumOperands() == 0 ||
+        payload.getNumOperands() != body->getNumArguments())
+      return nullptr;
+
+    // Check that operands match block arguments in order
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments())) {
       if (bbArg != operand)
         return nullptr;
     }
   }
+
   return &payload;
 }
 

>From 16080a0246e12ec9527ccae32481de2b989ae853 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 08:43:26 -0700
Subject: [PATCH 2/4] cleanup

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 90 +++++++++---------------
 1 file changed, 33 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e2e6893eac21..7af4ea6a2f3a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,74 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-// Check if a block contains a single payload operation that can be printed in
-// short form. The block must contain exactly 2 operations: the payload op and a
-// yield.
-static bool canUseShortForm(Block *body) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+  // Check if the body can be printed in short form. The following 4 conditions
+  // must be satisfied:
+
+  // 1) The body must contain exactly 2 operations: the payload op and a yield.
   if (body->getOperations().size() != 2)
     return false;
-
   Operation &payload = body->getOperations().front();
-  assert(isa<YieldOp>(body->getOperations().back()));
-
-  // Check that the yield has exactly one operand that comes from the payload
-  auto yieldOp = cast<YieldOp>(body->getOperations().back());
-  if (yieldOp.getNumOperands() != 1)
-    return false;
 
-  Value yieldOperand = yieldOp.getOperand(0);
-  if (!yieldOperand.getDefiningOp() || yieldOperand.getDefiningOp() != &payload)
+  // 2) The payload op must have at least one operand, and exactly as many
+  //    operands as there are block arguments.
+  if (payload.getNumOperands() == 0 ||
+      payload.getNumOperands() != body->getNumArguments())
     return false;
 
-  return true;
-}
-
-// Find a payload operation that can be printed in short form.
-// For MapOp (initFirst=false): operands must match block arguments in order.
-// For ReduceOp (initFirst=true): init operand must be first, then operands must
-// match block arguments.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
-  if (!canUseShortForm(body))
-    return nullptr;
-
-  Operation &payload = body->getOperations().front();
-
+  // 3) If `initFirst` is true (e.g., for reduction ops), block arg 0 (init)
+  //    must be the last payload operand and the remaining block args must
+  //    match the other operands in order; otherwise, all must match in order.
   if (initFirst) {
-    // For ReduceOp: check that operand count matches block argument count + 1
-    // (for init)
-    if (payload.getNumOperands() == 0 ||
-        payload.getNumOperands() != body->getNumArguments() + 1)
-      return nullptr;
-
-    // Check that init operand is first
-    if (payload.getOperands().front() != body->getArgument(0))
-      return nullptr;
-
-    // Check that remaining operands match block arguments in order
+    // check init
+    if (payload.getOperands().back() != body->getArgument(0))
+      return false;
+    // check rest
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands().drop_front(),
-                   body->getArguments().drop_front())) {
+         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   } else {
-    // For MapOp: check that operand count matches block argument count
-    if (payload.getNumOperands() == 0 ||
-        payload.getNumOperands() != body->getNumArguments())
-      return nullptr;
-
-    // Check that operands match block arguments in order
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   }
 
-  return &payload;
+  // 4) The `yield` must have exactly one operand: the result of the payload op.
+  auto yieldOp = cast<YieldOp>(body->getTerminator());
+  return yieldOp.getNumOperands() == 1 &&
+         yieldOp.getOperand(0).getDefiningOp() &&
+         yieldOp.getOperand(0).getDefiningOp() == &payload;
 }
 
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
   SmallVector<StringRef> elidedAttrs;
   std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
@@ -1656,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   p.printOptionalAttrDict((*this)->getAttrs());
 
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();
@@ -1863,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 
 void ReduceOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();

>From e50edb76e170a8bdf8a4cedb16df0a39c2aac3e5 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 12:41:47 -0700
Subject: [PATCH 3/4] add lit tests for reduce and map cases that can only use
 long form printing

---
 mlir/test/Dialect/Linalg/roundtrip.mlir | 51 +++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6eda3eae..a09348c69d3a3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 //  CHECK-SAME:   outs
 //  CHECK-SAME:   dimensions = [1]
 
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+                             %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  %reduce = linalg.reduce
+    ins(%input:tensor<16x32x64xf32>)
+    outs(%init:tensor<16x64xf32>)
+    dimensions = [1]
+    (%in1: f32, %in2: f32) {
+      %0 = arith.addf %in1, %in2: f32
+      linalg.yield %in1: f32
+    }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME:  %[[ARG1:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
+// CHECK:     linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>) 
+// CHECK-SAME:  dimensions = [1]
+// CHECK:       (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:    %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT:    linalg.yield %[[IN1]] : f32
+// CHECK-NEXT:  }
+
 // -----
 
 func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 // -----
 
+func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
+  %res = tensor.empty() : tensor<1x32xf32>
+  %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+    (%in_1: f32, %in_2: f32) {
+      %1 = arith.maximumf %in_1, %in_2 : f32
+      linalg.yield %in_1 : f32
+    }
+  func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME:        %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
+// CHECK:         %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
+// CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
+// CHECK:         linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>) 
+// CHECK-SAME:               outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT:        linalg.yield %[[IN1]] : f32
+// CHECK-NEXT:    }
+
+// -----
+
 func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   %reduce = linalg.reduce

>From 5189c8f6e7eea2d3134fb2f29009720d53747ef3 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Wed, 13 Aug 2025 21:08:04 -0700
Subject: [PATCH 4/4] address comments

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 20 ++++++++++-------
 mlir/test/Dialect/Linalg/roundtrip.mlir       | 22 +++++++++----------
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e32d3d01bb182..f3674c3eecfe6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
           }
     ```
 
-    Shortened print form is available. Applies to simple maps with one
-    non-yield operation inside the body.
+    Shortened print form is available for simple maps where the body contains exactly
+    two operations (the payload operation and a yield), the payload operation has
+    the same number of operands as block arguments with operands matching block
+    arguments in order, and the yield operand is the result of the payload operation.
 
-    The example above will be printed as:
+    The example above will be printed using the shortened form as:
     ```mlir
       %add = linalg.map { arith.addf }
           ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
           }
     ```
 
-    Shortened print form is available. Applies to simple (not variadic) reduces
-    with one non-yield operation inside the body. Applies only if the operation
-    takes `%out` as the first argument.
+    Shortened print form is available for simple reduces where the body contains exactly
+    two operations (the payload operation and a yield), the payload operation has the
+    same number of operands as block arguments, the first block argument (init) is the
+    last operand of the payload operation with remaining operands matching remaining
+    block arguments in order, and the yield operand is the result of the payload operation.
 
-    The example above will be printed as:
+    The example above will be printed using the shortened form as:
     ```mlir
-          %reduce = linalg.reduce { arith.addf }
+      %reduce = linalg.reduce { arith.addf }
           ins(%input:tensor<16x32x64xf32>)
           outs(%init:tensor<16x64xf32>)
           dimensions = [1]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index a09348c69d3a3..563013d4083af 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -454,10 +454,10 @@ func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
 }
 
 // CHECK-LABEL: func @reduce_not_short_form_compatible
-// CHECK-SAME:  %[[ARG0:.*]]: tensor<16x32x64xf32>
-// CHECK-SAME:  %[[ARG1:.*]]: tensor<16x64xf32>
-// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
-// CHECK:     linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>) 
+// CHECK-SAME:  %[[INPUT:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME:  %[[INIT:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
+// CHECK:     linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>) 
 // CHECK-SAME:  dimensions = [1]
 // CHECK:       (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
 // CHECK-NEXT:    %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
@@ -620,9 +620,8 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 // -----
 
-func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
-  %res = tensor.empty() : tensor<1x32xf32>
-  %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
+  %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
     (%in_1: f32, %in_2: f32) {
       %1 = arith.maximumf %in_1, %in_2 : f32
       linalg.yield %in_1 : f32
@@ -631,11 +630,10 @@ func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<
 }
 
 // CHECK-LABEL: func @map_not_short_form_compatible
-// CHECK-SAME:        %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
-// CHECK:         %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
-// CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
-// CHECK:         linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>) 
-// CHECK-SAME:               outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-SAME:        %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
+// CHECK-NOT:     linalg.map { arith.maximumf }
+// CHECK:         linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) 
+// CHECK-SAME:               outs(%[[INIT]] : tensor<1x32xf32>)
 // CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
 // CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
 // CHECK-NEXT:        linalg.yield %[[IN1]] : f32



More information about the Mlir-commits mailing list