[Mlir-commits] [mlir] [mlir][IR] Add multi-type `getParentOfType` overloads (PR #184071)

Matthias Springer llvmlistbot at llvm.org
Mon Mar 2 00:30:02 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/184071

>From d7d26e5ca76382d63e1f4a92df9811fd92772436 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 2 Mar 2026 07:54:09 +0000
Subject: [PATCH] [mlir][IR] Add multi-type `getParentOfType` overloads

---
 mlir/include/mlir/IR/Operation.h              |  8 +++++
 mlir/include/mlir/IR/Region.h                 | 11 +++++++
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 30 +++++--------------
 .../OpenACC/Transforms/LegalizeDataValues.cpp |  9 +-----
 .../Dialect/OpenACC/Utils/OpenACCUtils.cpp    |  8 +----
 .../Transforms/Utils/CodegenUtils.cpp         |  6 ++--
 6 files changed, 30 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..ea9fbab4acf9c 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -242,6 +242,14 @@ class alignas(8) Operation final
         return parentOp;
     return OpTy();
   }
+  template <typename... OpTy>
+  std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
+    auto *op = this;
+    while ((op = op->getParentOp()))
+      if (isa<OpTy...>(op))
+        return op;
+    return nullptr;
+  }
 
   /// Returns the closest surrounding parent operation with trait `Trait`.
   template <template <typename T> class Trait>
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 53d461df98710..13b54991832cb 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -210,6 +210,17 @@ class Region {
     } while ((region = region->getParentRegion()));
     return ParentT();
   }
+  template <typename... ParentT>
+  std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
+    auto *region = this;
+    do {
+      if (!region->container)
+        return nullptr;
+      if (isa<ParentT...>(region->container))
+        return region->container;
+    } while ((region = region->getParentRegion()));
+    return nullptr;
+  }
 
   /// Return the number of this region in the parent operation.
   unsigned getRegionNumber();
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ec164a892d67..800305ffb36c5 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
 // It is quite alike acc::getEnclosingComputeOp() utility,
 // but we cannot use it here.
 static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
-  mlir::Operation *parentOp = op->getParentOp();
-  while (parentOp) {
-    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
-      return true;
-    parentOp = parentOp->getParentOp();
-  }
-  return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 /// Helper to add an effect on an operand, referenced by its mutable range.
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
   return success();
 }
 
-static bool isComputeOperation(Operation *op) {
-  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
-}
-
 namespace {
 /// Pattern to remove operation without region that have constant false `ifCond`
 /// and remove the condition from the operation if the `ifCond` is a true
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::InitOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   return success();
 }
 
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::ShutdownOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   return success();
 }
 
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::SetOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
     return emitOpError("at least one default_async, device_num, or device_type "
                        "operand must appear");
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 6d316d282278d..8b15d2c1cc7f2 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -27,14 +27,7 @@ using namespace mlir;
 namespace {
 
 static bool insideAccComputeRegion(mlir::Operation *op) {
-  mlir::Operation *parent{op->getParentOp()};
-  while (parent) {
-    if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
-      return true;
-    }
-    parent = parent->getParentOp();
-  }
-  return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 static void collectVars(mlir::ValueRange operands,
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 6e5ec0ccd1210..911f256a3d2a6 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -21,13 +21,7 @@
 #include "llvm/Support/Casting.h"
 
 mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
-  mlir::Operation *parentOp = region.getParentOp();
-  while (parentOp) {
-    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
-      return parentOp;
-    parentOp = parentOp->getParentOp();
-  }
-  return nullptr;
+  return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 template <typename OpTy>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index f57f7f7fc0946..3c6a905ebb696 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -418,10 +418,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
 }
 
 Operation *mlir::sparse_tensor::getTop(Operation *op) {
-  for (; isa<scf::ForOp>(op->getParentOp()) ||
-         isa<scf::WhileOp>(op->getParentOp()) ||
-         isa<scf::ParallelOp>(op->getParentOp()) ||
-         isa<scf::IfOp>(op->getParentOp());
+  for (; isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
+           op->getParentOp());
        op = op->getParentOp())
     ;
   return op;



More information about the Mlir-commits mailing list