[Mlir-commits] [mlir] 3f1d968 - [mlir][IR] Add variadic `getParentOfType` overloads (#184071)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 08:22:53 PST 2026
Author: Matthias Springer
Date: 2026-03-03T18:22:48+02:00
New Revision: 3f1d968db946e90c7e29bfa886566957f0e374f4
URL: https://github.com/llvm/llvm-project/commit/3f1d968db946e90c7e29bfa886566957f0e374f4
DIFF: https://github.com/llvm/llvm-project/commit/3f1d968db946e90c7e29bfa886566957f0e374f4.diff
LOG: [mlir][IR] Add variadic `getParentOfType` overloads (#184071)
Add `getParentOfType` overloads that work with multiple types.
Added:
Modified:
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..ea9fbab4acf9c 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -242,6 +242,14 @@ class alignas(8) Operation final
return parentOp;
return OpTy();
}
+ template <typename... OpTy>
+ std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
+ auto *op = this;
+ while ((op = op->getParentOp()))
+ if (isa<OpTy...>(op))
+ return op;
+ return nullptr;
+ }
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 53d461df98710..13b54991832cb 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -210,6 +210,17 @@ class Region {
} while ((region = region->getParentRegion()));
return ParentT();
}
+ template <typename... ParentT>
+ std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
+ auto *region = this;
+ do {
+ if (!region->container)
+ return nullptr;
+ if (isa<ParentT...>(region->container))
+ return region->container;
+ } while ((region = region->getParentRegion()));
+ return nullptr;
+ }
/// Return the number of this region in the parent operation.
unsigned getRegionNumber();
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ec164a892d67..800305ffb36c5 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
// It is quite alike acc::getEnclosingComputeOp() utility,
// but we cannot use it here.
static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
- mlir::Operation *parentOp = op->getParentOp();
- while (parentOp) {
- if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
- return true;
- parentOp = parentOp->getParentOp();
- }
- return false;
+ return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
/// Helper to add an effect on an operand, referenced by its mutable range.
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
return success();
}
-static bool isComputeOperation(Operation *op) {
- return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
-}
-
namespace {
/// Pattern to remove operation without region that have constant false `ifCond`
/// and remove the condition from the operation if the `ifCond` is a true
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::InitOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+ if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+ return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::ShutdownOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+ if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+ return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::SetOp::verify() {
- Operation *currOp = *this;
- while ((currOp = currOp->getParentOp()))
- if (isComputeOperation(currOp))
- return emitOpError("cannot be nested in a compute operation");
+ if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+ return emitOpError("cannot be nested in a compute operation");
if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
return emitOpError("at least one default_async, device_num, or device_type "
"operand must appear");
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 6d316d282278d..8b15d2c1cc7f2 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -27,14 +27,7 @@ using namespace mlir;
namespace {
static bool insideAccComputeRegion(mlir::Operation *op) {
- mlir::Operation *parent{op->getParentOp()};
- while (parent) {
- if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
- return true;
- }
- parent = parent->getParentOp();
- }
- return false;
+ return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
static void collectVars(mlir::ValueRange operands,
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 6e5ec0ccd1210..911f256a3d2a6 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -21,13 +21,7 @@
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
- mlir::Operation *parentOp = region.getParentOp();
- while (parentOp) {
- if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
- return parentOp;
- parentOp = parentOp->getParentOp();
- }
- return nullptr;
+ return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
template <typename OpTy>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index ffc898c9933c3..614c29ba68481 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -430,10 +430,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
}
Operation *mlir::sparse_tensor::getTop(Operation *op) {
- for (; isa<scf::ForOp>(op->getParentOp()) ||
- isa<scf::WhileOp>(op->getParentOp()) ||
- isa<scf::ParallelOp>(op->getParentOp()) ||
- isa<scf::IfOp>(op->getParentOp());
+ for (; isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
+ op->getParentOp());
op = op->getParentOp())
;
return op;
More information about the Mlir-commits
mailing list