[Mlir-commits] [mlir] [mlir][linalg] Fix incorrect linalg short form printing (PR #153219)
Boyana Norris
llvmlistbot at llvm.org
Wed Aug 13 21:08:44 PDT 2025
https://github.com/brnorris03 updated https://github.com/llvm/llvm-project/pull/153219
>From c24793e4b707d74b727887849e4daad45d513322 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 07:52:12 -0700
Subject: [PATCH 1/4] add yield check
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 58 +++++++++++++++++++-----
1 file changed, 46 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..1e2e6893eac21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,36 +1570,70 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+// Check if a block contains a single payload operation that can be printed in
+// short form. The block must contain exactly 2 operations: the payload op and a
+// yield.
+static bool canUseShortForm(Block *body) {
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
+
Operation &payload = body->getOperations().front();
assert(isa<YieldOp>(body->getOperations().back()));
- if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ // Check that the yield has exactly one operand that comes from the payload
+ auto yieldOp = cast<YieldOp>(body->getOperations().back());
+ if (yieldOp.getNumOperands() != 1)
+ return false;
+
+ Value yieldOperand = yieldOp.getOperand(0);
+ if (!yieldOperand.getDefiningOp() || yieldOperand.getDefiningOp() != &payload)
+ return false;
+
+ return true;
+}
+
+// Find a payload operation that can be printed in short form.
+// For MapOp (initFirst=false): operands must match block arguments in order.
+// For ReduceOp (initFirst=true): init operand must be first, then operands must
+// match block arguments.
+static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+ if (!canUseShortForm(body))
return nullptr;
+
+ Operation &payload = body->getOperations().front();
+
if (initFirst) {
- // check init
- if (payload.getOperands().back() != body->getArgument(0))
+ // For ReduceOp: check that operand count matches block argument count + 1
+ // (for init)
+ if (payload.getNumOperands() == 0 ||
+ payload.getNumOperands() != body->getNumArguments() + 1)
+ return nullptr;
+
+ // Check that init operand is first
+ if (payload.getOperands().front() != body->getArgument(0))
return nullptr;
- // check rest
+
+ // Check that remaining operands match block arguments in order
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
+ llvm::zip(payload.getOperands().drop_front(),
+ body->getArguments().drop_front())) {
if (bbArg != operand)
return nullptr;
}
} else {
+ // For MapOp: check that operand count matches block argument count
+ if (payload.getNumOperands() == 0 ||
+ payload.getNumOperands() != body->getNumArguments())
+ return nullptr;
+
+ // Check that operands match block arguments in order
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
return nullptr;
}
}
+
return &payload;
}
>From 16080a0246e12ec9527ccae32481de2b989ae853 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 08:43:26 -0700
Subject: [PATCH 2/4] cleanup
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 90 +++++++++---------------
1 file changed, 33 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e2e6893eac21..7af4ea6a2f3a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,74 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Check if a block contains a single payload operation that can be printed in
-// short form. The block must contain exactly 2 operations: the payload op and a
-// yield.
-static bool canUseShortForm(Block *body) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
return false;
-
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
-
- // Check that the yield has exactly one operand that comes from the payload
- auto yieldOp = cast<YieldOp>(body->getOperations().back());
- if (yieldOp.getNumOperands() != 1)
- return false;
- Value yieldOperand = yieldOp.getOperand(0);
- if (!yieldOperand.getDefiningOp() || yieldOperand.getDefiningOp() != &payload)
+ // 2) The payload op must have at least one operand and exactly as many
+ // operands as there are block arguments.
+ if (payload.getNumOperands() == 0 ||
+ payload.getNumOperands() != body->getNumArguments())
return false;
- return true;
-}
-
-// Find a payload operation that can be printed in short form.
-// For MapOp (initFirst=false): operands must match block arguments in order.
-// For ReduceOp (initFirst=true): init operand must be first, then operands must
-// match block arguments.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
- if (!canUseShortForm(body))
- return nullptr;
-
- Operation &payload = body->getOperations().front();
-
+ // 3) If `initFirst` is true (e.g., for reduction ops), the payload operands
+ // must be the block arguments rotated left by one (so the init comes first
+ // and the first block argument last); otherwise they must match in order.
if (initFirst) {
- // For ReduceOp: check that operand count matches block argument count + 1
- // (for init)
- if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments() + 1)
- return nullptr;
-
- // Check that init operand is first
- if (payload.getOperands().front() != body->getArgument(0))
- return nullptr;
-
- // Check that remaining operands match block arguments in order
+ // check init
+ if (payload.getOperands().back() != body->getArgument(0))
+ return false;
+ // check rest
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands().drop_front(),
- body->getArguments().drop_front())) {
+ llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
- // For MapOp: check that operand count matches block argument count
- if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
- return nullptr;
-
- // Check that operands match block arguments in order
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+ // 4) The `yield` must have exactly one operand: the payload op's result.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1656,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1863,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
>From e50edb76e170a8bdf8a4cedb16df0a39c2aac3e5 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Tue, 12 Aug 2025 12:41:47 -0700
Subject: [PATCH 3/4] add lit tests for reduce and map cases that can only use
long form printing
---
mlir/test/Dialect/Linalg/roundtrip.mlir | 51 +++++++++++++++++++++++++
1 file changed, 51 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6eda3eae..a09348c69d3a3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32) {
+ %0 = arith.addf %in1, %in2: f32
+ linalg.yield %in1: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
+// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
+// CHECK-SAME: dimensions = [1]
+// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
// -----
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// -----
+func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
+ %res = tensor.empty() : tensor<1x32xf32>
+ %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+ (%in_1: f32, %in_2: f32) {
+ %1 = arith.maximumf %in_1, %in_2 : f32
+ linalg.yield %in_1 : f32
+ }
+ func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
+// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
+// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
+// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
+// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
+// -----
+
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
>From 5189c8f6e7eea2d3134fb2f29009720d53747ef3 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Wed, 13 Aug 2025 21:08:04 -0700
Subject: [PATCH 4/4] address comments
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 20 ++++++++++-------
mlir/test/Dialect/Linalg/roundtrip.mlir | 22 +++++++++----------
2 files changed, 22 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e32d3d01bb182..f3674c3eecfe6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
```
- Shortened print form is available. Applies to simple maps with one
- non-yield operation inside the body.
+ Shortened print form is available for simple maps where the body contains exactly
+ two operations (the payload operation and a yield), the payload operation takes
+ the block arguments as its operands, in order (and has at least one operand),
+ and the yield operand is the result of the payload operation.
- The example above will be printed as:
+ The example above will be printed using the shortened form as:
```mlir
%add = linalg.map { arith.addf }
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
```
- Shortened print form is available. Applies to simple (not variadic) reduces
- with one non-yield operation inside the body. Applies only if the operation
- takes `%out` as the first argument.
+ Shortened print form is available for simple reduces where the body contains exactly
+ two operations (the payload operation and a yield), the payload operation has the
+ same number of operands as block arguments, the first block argument (the input) is
+ the last payload operand, the remaining operands (the init first) match the remaining
+ block arguments in order, and the yield operand is the result of the payload operation.
- The example above will be printed as:
+ The example above will be printed using the shortened form as:
```mlir
- %reduce = linalg.reduce { arith.addf }
+ %reduce = linalg.reduce { arith.addf }
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
dimensions = [1]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index a09348c69d3a3..563013d4083af 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -454,10 +454,10 @@ func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
}
// CHECK-LABEL: func @reduce_not_short_form_compatible
-// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
-// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
-// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
-// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
+// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
+// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>)
// CHECK-SAME: dimensions = [1]
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
@@ -620,9 +620,8 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// -----
-func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
- %res = tensor.empty() : tensor<1x32xf32>
- %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
+ %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
(%in_1: f32, %in_2: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
@@ -631,11 +630,10 @@ func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<
}
// CHECK-LABEL: func @map_not_short_form_compatible
-// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
-// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
-// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
-// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
-// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
+// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
+// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
More information about the Mlir-commits
mailing list