[Mlir-commits] [mlir] [mlir][IR] Add multi-type `getParentOfType` overloads (PR #184071)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 1 23:58:43 PST 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/184071
Add `getParentOfType` overloads that work with multiple types.
>From ff32bc28dac7b605f034344a1794a82535f3aeab Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 2 Mar 2026 07:54:09 +0000
Subject: [PATCH] [mlir][IR] Add multi-type `getParentOfType` overloads
---
mlir/include/mlir/IR/Operation.h | 8 +++++
mlir/include/mlir/IR/Region.h | 9 ++++++
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 30 +++++--------------
.../OpenACC/Transforms/LegalizeDataValues.cpp | 9 +-----
.../Dialect/OpenACC/Utils/OpenACCUtils.cpp | 8 +----
.../Transforms/Utils/CodegenUtils.cpp | 6 ++--
6 files changed, 28 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..ea9fbab4acf9c 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -242,6 +242,19 @@ class alignas(8) Operation final
return parentOp;
return OpTy();
}
+
+  /// Return the closest surrounding parent operation (not including this
+  /// operation) whose type is any of `OpTy...`, or nullptr if there is none.
+  /// Unlike the single-type overload, the result is an untyped `Operation *`
+  /// because the matched type is only known at runtime. Needs >= 2 types.
+  template <typename... OpTy>
+  std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
+    Operation *op = this;
+    while ((op = op->getParentOp()))
+      if (isa<OpTy...>(op))
+        return op;
+    return nullptr;
+  }
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 53d461df98710..977c8f9999a84 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -210,6 +210,21 @@ class Region {
} while ((region = region->getParentRegion()));
return ParentT();
}
+
+  /// Return the closest surrounding parent operation whose type is any of
+  /// `ParentT...`, starting with the operation that owns this region, or
+  /// nullptr if there is no such parent. Needs two or more types.
+  template <typename... ParentT>
+  std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
+    auto *region = this;
+    do {
+      // `container` may be null and `isa` must not be given a null pointer,
+      // so only a non-null container of a requested type is a match.
+      if (region->container && isa<ParentT...>(region->container))
+        return region->container;
+    } while ((region = region->getParentRegion()));
+    return nullptr;
+  }
/// Return the number of this region in the parent operation.
unsigned getRegionNumber();
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ec164a892d67..73dd4def00cab 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
// It is quite alike acc::getEnclosingComputeOp() utility,
// but we cannot use it here.
static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
- mlir::Operation *parentOp = op->getParentOp();
- while (parentOp) {
- if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
- return true;
- parentOp = parentOp->getParentOp();
- }
- return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>() != nullptr;
}
/// Helper to add an effect on an operand, referenced by its mutable range.
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
return success();
}
-static bool isComputeOperation(Operation *op) {
- return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
-}
-
namespace {
/// Pattern to remove operation without region that have constant false `ifCond`
/// and remove the condition from the operation if the `ifCond` is a true
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::InitOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::ShutdownOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::SetOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
return emitOpError("at least one default_async, device_num, or device_type "
"operand must appear");
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 6d316d282278d..8b15d2c1cc7f2 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -27,14 +27,7 @@ using namespace mlir;
namespace {
static bool insideAccComputeRegion(mlir::Operation *op) {
- mlir::Operation *parent{op->getParentOp()};
- while (parent) {
- if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
- return true;
- }
- parent = parent->getParentOp();
- }
- return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>() != nullptr;
}
static void collectVars(mlir::ValueRange operands,
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 6e5ec0ccd1210..911f256a3d2a6 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -21,13 +21,7 @@
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
- mlir::Operation *parentOp = region.getParentOp();
- while (parentOp) {
- if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
- return parentOp;
- parentOp = parentOp->getParentOp();
- }
- return nullptr;
+  return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
template <typename OpTy>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index f57f7f7fc0946..3c6a905ebb696 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -418,10 +418,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
}
Operation *mlir::sparse_tensor::getTop(Operation *op) {
- for (; isa<scf::ForOp>(op->getParentOp()) ||
- isa<scf::WhileOp>(op->getParentOp()) ||
- isa<scf::ParallelOp>(op->getParentOp()) ||
- isa<scf::IfOp>(op->getParentOp());
+  for (; isa<scf::IfOp, scf::ForOp, scf::ParallelOp, scf::WhileOp>(
+           op->getParentOp());
op = op->getParentOp())
;
return op;
More information about the Mlir-commits
mailing list