[Mlir-commits] [mlir] [MLIR][Linalg] Fix crash when parsing linalg.elementwise with vector inputs (#178363) (PR #179170)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 1 22:59:38 PST 2026
https://github.com/IamYJLee updated https://github.com/llvm/llvm-project/pull/179170
>From bbc3187045fed4e06c41f2ce671b2eced7e326c2 Mon Sep 17 00:00:00 2001
From: LeeYoungJoon <dog3hk.dev at gmail.com>
Date: Mon, 2 Feb 2026 15:07:58 +0900
Subject: [PATCH 1/2] [MLIR][Linalg] Fix crash when parsing linalg.elementwise
with vector inputs
Fix a crash when linalg.elementwise receives vector operands instead of tensors or memrefs.
The parser now checks operand types and rejects any operand that is not a ranked tensor or memref (as well as a mix of tensors and memrefs) with a diagnostic, instead of hitting an unreachable assertion during parsing.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 64 +++++++++++++++++++++---
1 file changed, 56 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..496b7905211e6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4847,12 +4847,57 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
int numRegionArgs =
getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
- if (parseNamedStructuredOp(parser, result, numRegionArgs,
- ElementwiseOp::getRegionBuilder())) {
- return parser.emitError(parser.getCurrentLocation(),
- "unable to parse elemwise op");
+
+ // Parse structured op parts (ins/outs)
+ SmallVector<Type, 1> inputTypes, outputTypes;
+ SMLoc loc = parser.getCurrentLocation();
+ if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
+ return failure();
+
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ // Parse result types.
+ SmallVector<Type, 1> resultTypes;
+ if (parseNamedStructuredOpResults(parser, resultTypes))
+ return failure();
+ result.addTypes(resultTypes);
+
+ // Type validation (before region build)
+ for (auto [i, type] : llvm::enumerate(inputTypes)) {
+ if (!llvm::isa<RankedTensorType, MemRefType>(type)) {
+ return parser.emitError(loc)
+ << "input operand #" << i
+ << " must be a memref or ranked tensor, but got " << type;
+ }
+ }
+ for (auto [i, type] : llvm::enumerate(outputTypes)) {
+ if (!llvm::isa<RankedTensorType, MemRefType>(type)) {
+ return parser.emitError(loc)
+ << "output operand #" << i
+ << " must be a memref or ranked tensor, but got " << type;
+ }
+ }
+
+ bool hasTensor = llvm::any_of(inputTypes, llvm::IsaPred<RankedTensorType>) ||
+ llvm::any_of(outputTypes, llvm::IsaPred<RankedTensorType>);
+ bool hasMemref = llvm::any_of(inputTypes, llvm::IsaPred<MemRefType>) ||
+ llvm::any_of(outputTypes, llvm::IsaPred<MemRefType>);
+ if (hasTensor && hasMemref) {
+ return parser.emitError(loc)
+ << "input and output operands must have the same type category "
+ "(all tensors or all memrefs)";
}
+ // Build region
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
+ outputTypes, result.attributes.getAttrs(),
+ ElementwiseOp::getRegionBuilder(), loc))
+ return failure();
+ result.addRegion(std::move(region));
+
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
// We need to infer the numDims of the indexing maps from the output
@@ -4929,20 +4974,23 @@ void ElementwiseOp::regionBuilder(
Value result;
if (arityGroup == ElementwiseArityGroup::Unary) {
- result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+ result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0), emitError);
} else if (arityGroup == ElementwiseArityGroup::Binary) {
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
- block.getArgument(1));
-
+ block.getArgument(1), emitError);
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
- block.getArgument(1), block.getArgument(2));
+ block.getArgument(1), block.getArgument(2),
+ emitError);
} else {
assert(false && "found unhandled category in elemwise");
}
+ if (!result)
+ return;
+
yields.push_back(result);
helper.yieldOutputs(yields);
}
>From 609e222e4e4444ebd940cf925d620c813e4aa127 Mon Sep 17 00:00:00 2001
From: LeeYoungJoon <dog3hk.dev at gmail.com>
Date: Mon, 2 Feb 2026 15:59:12 +0900
Subject: [PATCH 2/2] Update test cases to expect proper diagnostics for
invalid vector operands instead of an unreachable assertion.
---
.../Dialect/Linalg/elementwise/invalid.mlir | 22 +++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/elementwise/invalid.mlir b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
index fe03519ae94b8..db6754b313eb1 100644
--- a/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
@@ -38,8 +38,7 @@ func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C
// -----
func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
- // expected-error at +3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
- // expected-error at +2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+ // expected-error at +2 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
return
}
@@ -47,8 +46,23 @@ func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>
// -----
func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
- // expected-error at +3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
- // expected-error at +2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+ // expected-error at +2 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
return
}
+
+// -----
+
+func.func @vector_args(%A : memref<16x8xf32>, %B: vector<16x8xf32>, %C: vector<16x8xf32>) {
+  // expected-error at +1 {{custom op 'linalg.elementwise' input operand #1 must be a memref or ranked tensor, but got 'vector<16x8xf32>'}}
+  linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A, %B: memref<16x8xf32>, vector<16x8xf32>) outs(%C: vector<16x8xf32>)
+  return
+}
+
+// -----
+
+func.func @input_output_mismatch(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: tensor<16x8xf32>) {
+  // expected-error at +1 {{custom op 'linalg.elementwise' input and output operands must have the same type category (all tensors or all memrefs)}}
+  linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: tensor<16x8xf32>)
+  return
+}
More information about the Mlir-commits
mailing list