[Mlir-commits] [mlir] 3f1d968 - [mlir][IR] Add variadic `getParentOfType` overloads (#184071)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 08:22:53 PST 2026


Author: Matthias Springer
Date: 2026-03-03T18:22:48+02:00
New Revision: 3f1d968db946e90c7e29bfa886566957f0e374f4

URL: https://github.com/llvm/llvm-project/commit/3f1d968db946e90c7e29bfa886566957f0e374f4
DIFF: https://github.com/llvm/llvm-project/commit/3f1d968db946e90c7e29bfa886566957f0e374f4.diff

LOG: [mlir][IR] Add variadic `getParentOfType` overloads (#184071)

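Add variadic `getParentOfType` overloads to `Operation` and `Region` that accept multiple op types. They return the closest enclosing `Operation *` that matches any of the given types, or `nullptr` if none does. Use them to replace hand-written parent-walking loops in the OpenACC dialect. Separately, collapse the chain of `isa` checks in SparseTensor's `getTop` into a single variadic `isa`.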

Added: 
    

Modified: 
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/Region.h
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
    mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..ea9fbab4acf9c 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -242,6 +242,14 @@ class alignas(8) Operation final
         return parentOp;
     return OpTy();
   }
+  template <typename... OpTy>
+  std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
+    auto *op = this;
+    while ((op = op->getParentOp()))
+      if (isa<OpTy...>(op))
+        return op;
+    return nullptr;
+  }
 
   /// Returns the closest surrounding parent operation with trait `Trait`.
   template <template <typename T> class Trait>

diff  --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 53d461df98710..13b54991832cb 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -210,6 +210,17 @@ class Region {
     } while ((region = region->getParentRegion()));
     return ParentT();
   }
+  template <typename... ParentT>
+  std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
+    auto *region = this;
+    do {
+      if (!region->container)
+        return nullptr;
+      if (isa<ParentT...>(region->container))
+        return region->container;
+    } while ((region = region->getParentRegion()));
+    return nullptr;
+  }
 
   /// Return the number of this region in the parent operation.
   unsigned getRegionNumber();

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ec164a892d67..800305ffb36c5 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
 // It is quite alike acc::getEnclosingComputeOp() utility,
 // but we cannot use it here.
 static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
-  mlir::Operation *parentOp = op->getParentOp();
-  while (parentOp) {
-    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
-      return true;
-    parentOp = parentOp->getParentOp();
-  }
-  return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 /// Helper to add an effect on an operand, referenced by its mutable range.
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
   return success();
 }
 
-static bool isComputeOperation(Operation *op) {
-  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
-}
-
 namespace {
 /// Pattern to remove operation without region that have constant false `ifCond`
 /// and remove the condition from the operation if the `ifCond` is a true
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::InitOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   return success();
 }
 
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::ShutdownOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   return success();
 }
 
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::SetOp::verify() {
-  Operation *currOp = *this;
-  while ((currOp = currOp->getParentOp()))
-    if (isComputeOperation(currOp))
-      return emitOpError("cannot be nested in a compute operation");
+  if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
+    return emitOpError("cannot be nested in a compute operation");
   if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
     return emitOpError("at least one default_async, device_num, or device_type "
                        "operand must appear");

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 6d316d282278d..8b15d2c1cc7f2 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -27,14 +27,7 @@ using namespace mlir;
 namespace {
 
 static bool insideAccComputeRegion(mlir::Operation *op) {
-  mlir::Operation *parent{op->getParentOp()};
-  while (parent) {
-    if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
-      return true;
-    }
-    parent = parent->getParentOp();
-  }
-  return false;
+  return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 static void collectVars(mlir::ValueRange operands,

diff  --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 6e5ec0ccd1210..911f256a3d2a6 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -21,13 +21,7 @@
 #include "llvm/Support/Casting.h"
 
 mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
-  mlir::Operation *parentOp = region.getParentOp();
-  while (parentOp) {
-    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
-      return parentOp;
-    parentOp = parentOp->getParentOp();
-  }
-  return nullptr;
+  return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
 }
 
 template <typename OpTy>

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index ffc898c9933c3..614c29ba68481 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -430,10 +430,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
 }
 
 Operation *mlir::sparse_tensor::getTop(Operation *op) {
-  for (; isa<scf::ForOp>(op->getParentOp()) ||
-         isa<scf::WhileOp>(op->getParentOp()) ||
-         isa<scf::ParallelOp>(op->getParentOp()) ||
-         isa<scf::IfOp>(op->getParentOp());
+  for (; isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
+           op->getParentOp());
        op = op->getParentOp())
     ;
   return op;


        


More information about the Mlir-commits mailing list