[Mlir-commits] [mlir] [MLIR][OpenMP] Move loop wrapper verification to the interface (NFC) (PR #110505)

Sergio Afonso llvmlistbot at llvm.org
Mon Sep 30 06:07:16 PDT 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/110505

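This patch moves the verification code for `LoopWrapperInterface` into the interface itself, so that it is checked automatically for every operation that implements the interface. Because interface verifiers run before an operation's own `verify()` method, several tests in `invalid.mlir` were updated to contain well-formed loop wrappers, so that they still reach the diagnostics they check.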

>From af97caf480db2ff92d48119e42749f6a294fe82d Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 30 Sep 2024 14:03:16 +0100
Subject: [PATCH] [MLIR][OpenMP] Move loop wrapper verification to the
 interface (NFC)

This patch moves the verification code for `LoopWrapperInterface` into
the interface itself, so that it is checked automatically for every
operation implementing that interface.
---
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     |  9 +++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 69 +++++++++----------
 mlir/test/Dialect/OpenMP/invalid.mlir         | 46 ++++++++++---
 3 files changed, 76 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0078e22b1c89a6..ea1e3ebecef7b4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -106,6 +106,15 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       }]
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Interface verifier implementation, called by the `verify` hook.
+    llvm::LogicalResult verifyImpl();
+  }];
+
+  let verify = [{
+    return ::llvm::cast<::mlir::omp::LoopWrapperInterface>($_op).verifyImpl();
+  }];
 }
 
 def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..59e71ecc6ec5df 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1682,6 +1682,37 @@ LogicalResult SingleOp::verify() {
                                   getCopyprivateSyms());
 }
 
+//===----------------------------------------------------------------------===//
+// LoopWrapperInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult LoopWrapperInterface::verifyImpl() {
+  // A loop wrapper must have exactly one region with exactly one block, which
+  // holds a nested loop wrapper or `omp.loop_nest` followed by a terminator.
+  Operation *op = this->getOperation();
+  if (op->getNumRegions() != 1)
+    return emitOpError() << "loop wrapper must contain exactly one region";
+  Region &region = op->getRegion(0);
+  if (!region.hasOneBlock())
+    return emitOpError() << "loop wrapper must contain exactly one block";
+
+  if (::llvm::range_size(region.getOps()) != 2)
+    return emitOpError()
+           << "loop wrapper does not contain exactly two nested ops";
+
+  Operation &firstOp = *region.op_begin();
+  Operation &secondOp = *(std::next(region.op_begin()));
+  if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+    return emitOpError()
+           << "second nested op in loop wrapper is not a terminator";
+
+  if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
+    return emitOpError() << "first nested op in loop wrapper is not "
+                            "another loop wrapper or `omp.loop_nest`";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
@@ -1714,32 +1745,6 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
-static LogicalResult verifyLoopWrapperInterface(Operation *op) {
-  if (op->getNumRegions() != 1)
-    return op->emitOpError() << "loop wrapper contains multiple regions";
-
-  Region &region = op->getRegion(0);
-  if (!region.hasOneBlock())
-    return op->emitOpError() << "loop wrapper contains multiple blocks";
-
-  if (::llvm::range_size(region.getOps()) != 2)
-    return op->emitOpError()
-           << "loop wrapper does not contain exactly two nested ops";
-
-  Operation &firstOp = *region.op_begin();
-  Operation &secondOp = *(std::next(region.op_begin()));
-
-  if (!secondOp.hasTrait<OpTrait::IsTerminator>())
-    return op->emitOpError()
-           << "second nested op in loop wrapper is not a terminator";
-
-  if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
-    return op->emitOpError() << "first nested op in loop wrapper is not "
-                                "another loop wrapper or `omp.loop_nest`";
-
-  return success();
-}
-
 void WsloopOp::build(OpBuilder &builder, OperationState &state,
                      ArrayRef<NamedAttribute> attributes) {
   build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1770,9 +1775,6 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult WsloopOp::verify() {
-  if (verifyLoopWrapperInterface(*this).failed())
-    return failure();
-
   bool isCompositeChildLeaf =
       llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
 
@@ -1829,9 +1831,6 @@ LogicalResult SimdOp::verify() {
   if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
     return failure();
 
-  if (verifyLoopWrapperInterface(*this).failed())
-    return failure();
-
   if (getNestedWrapper())
     return emitOpError() << "must wrap an 'omp.loop_nest' directly";
 
@@ -1871,9 +1870,6 @@ LogicalResult DistributeOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  if (verifyLoopWrapperInterface(*this).failed())
-    return failure();
-
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
       return emitError()
@@ -2079,9 +2075,6 @@ LogicalResult TaskloopOp::verify() {
         "may not appear on the same taskloop directive");
   }
 
-  if (verifyLoopWrapperInterface(*this).failed())
-    return failure();
-
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
       return emitError()
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d8745f1015af83..35a8883e3a317e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -355,6 +355,7 @@ func.func @omp_simd_aligned_mismatch(%arg0 : index, %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {alignments = [128],
       operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
   return
@@ -370,6 +371,7 @@ func.func @omp_simd_aligned_negative(%arg0 : index, %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {alignments = [-1, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
   return
 }
@@ -384,6 +386,7 @@ func.func @omp_simd_unexpected_alignment(%arg0 : index, %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {alignments = [1, 128]} : () -> ()
   return
 }
@@ -398,6 +401,7 @@ func.func @omp_simd_aligned_float(%arg0 : index, %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {alignments = [1.5, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
   return
 }
@@ -412,6 +416,7 @@ func.func @omp_simd_aligned_the_same_var(%arg0 : index, %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {alignments = [1, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
   return
 }
@@ -426,6 +431,7 @@ func.func @omp_simd_nontemporal_the_same_var(%arg0 : index,  %arg1 : index,
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 2, 0, 0>} : (memref<i32>, memref<i32>) -> ()
   return
 }
@@ -438,6 +444,7 @@ func.func @omp_simd_order_value(%lb : index, %ub : index, %step : index) {
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -450,6 +457,7 @@ func.func @omp_simd_reproducible_order(%lb : index, %ub : index, %step : index)
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -460,6 +468,7 @@ func.func @omp_simd_unconstrained_order(%lb : index, %ub : index, %step : index)
     omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -470,6 +479,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -482,6 +492,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -494,6 +505,7 @@ func.func @omp_simd_pretty_simdlen_safelen(%lb : index, %ub : index, %step : ind
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -1838,6 +1850,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }) {operandSegmentSizes = array<i32: 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
   return
 }
@@ -1852,6 +1865,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>, reduction_syms = [@add_f32]} : (!llvm.ptr, !llvm.ptr) -> ()
   return
 }
@@ -1865,6 +1879,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>, reduction_syms = [@add_f32, @add_f32]} : (!llvm.ptr) -> ()
   return
 }
@@ -1879,6 +1894,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }) {in_reduction_syms = [@add_f32], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>} : (!llvm.ptr, !llvm.ptr) -> ()
   return
 }
@@ -1892,6 +1908,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }) {in_reduction_syms = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>} : (!llvm.ptr) -> ()
   return
 }
@@ -1918,6 +1935,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -1943,6 +1961,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -1956,6 +1975,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -2153,20 +2173,26 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
 
 // -----
 
-func.func @omp_distribute_schedule(%chunk_size : i32) -> () {
+func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
   // expected-error @below {{op chunk size set without dist_schedule_static being present}}
   "omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
-      "omp.terminator"() : () -> ()
-    }) : (i32) -> ()
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      "omp.yield"() : () -> ()
+    }
+    "omp.terminator"() : () -> ()
+  }) : (i32) -> ()
 }
 
 // -----
 
-func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_allocate(%data_var : memref<i32>, %lb : i32, %ub : i32, %step : i32) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>}> ({
-      "omp.terminator"() : () -> ()
-    }) : (memref<i32>) -> ()
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      "omp.yield"() : () -> ()
+    }
+    "omp.terminator"() : () -> ()
+  }) : (memref<i32>) -> ()
 }
 
 // -----
@@ -2174,10 +2200,10 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
 func.func @omp_distribute_wrapper(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{op second nested op in loop wrapper is not a terminator}}
   omp.distribute {
-      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
-        "omp.yield"() : () -> ()
-      }
-      %0 = arith.constant 0 : i32
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      "omp.yield"() : () -> ()
+    }
+    %0 = arith.constant 0 : i32
   }
 }
 



More information about the Mlir-commits mailing list