[Mlir-commits] [mlir] [mlir][openacc] Update verifier to catch missing device type attribute (PR #111586)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 8 14:03:15 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Operands with device_type support need the corresponding attribute, but the verifier did not catch the case where it was missing. The custom parser usually constructs it, but creating the op from Python could lead to a segfault in the printer. This patch updates the verifier so the problem is caught early.
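The rule the patch enforces can be modeled outside MLIR. The following is a minimal standalone sketch, not the MLIR API. The function name `checkDeviceTypeSegments` is hypothetical, `std::optional` stands in for a possibly-null `DenseI32ArrayAttr`/`ArrayAttr`, and an empty return string means the op verifies. It follows the updated `verifyDeviceTypeAndSegmentCountMatch`: operands that are present without segments, or without a device_type attribute, are rejected instead of reaching the printer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of the verifier check. std::nullopt models a missing
// (null) attribute. Returns "" on success, otherwise the diagnostic text.
std::string checkDeviceTypeSegments(
    const std::optional<std::vector<int32_t>> &segments,
    std::size_t numOperands,
    const std::optional<std::size_t> &numDeviceTypes,
    const std::string &keyword, int32_t maxInSegment = 0) {
  std::size_t numOperandsInSegments = 0;
  std::size_t nbOfSegments = 0;

  // Missing segments no longer short-circuit to success; they count as zero.
  if (segments) {
    for (int32_t segCount : *segments) {
      if (maxInSegment != 0 && segCount > maxInSegment)
        return keyword + " expects a maximum of " +
               std::to_string(maxInSegment) + " values per segment";
      numOperandsInSegments += static_cast<std::size_t>(segCount);
      ++nbOfSegments;
    }
  }

  // Operands must be fully covered by segments and need a device_type attr.
  if (numOperandsInSegments != numOperands ||
      (!numDeviceTypes && numOperands != 0))
    return keyword + " operand count does not match count in segments";

  // The device_type count is compared only when the attribute exists, which
  // avoids dereferencing a null attribute.
  if (numDeviceTypes && *numDeviceTypes != nbOfSegments)
    return keyword + " segment count does not match device_type count";

  return "";
}
```

With `segments = {1}`, one operand, and no device_type attribute (the shape of the new `invalid.mlir` test), the sketch returns the `operand count does not match count in segments` message.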

---
Full diff: https://github.com/llvm/llvm-project/pull/111586.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+14-11) 
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+7) 


``````````diff
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 877bd226a03528..919a0853fb6049 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -759,20 +759,23 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
     Op op, OperandRange operands, DenseI32ArrayAttr segments,
     ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
   std::size_t numOperandsInSegments = 0;
-
-  if (!segments)
-    return success();
-
-  for (auto segCount : segments.asArrayRef()) {
-    if (maxInSegment != 0 && segCount > maxInSegment)
-      return op.emitOpError() << keyword << " expects a maximum of "
-                              << maxInSegment << " values per segment";
-    numOperandsInSegments += segCount;
+  std::size_t nbOfSegments = 0;
+
+  if (segments) {
+    for (auto segCount : segments.asArrayRef()) {
+      if (maxInSegment != 0 && segCount > maxInSegment)
+        return op.emitOpError() << keyword << " expects a maximum of "
+                                << maxInSegment << " values per segment";
+      numOperandsInSegments += segCount;
+      ++nbOfSegments;
+    }
   }
-  if (numOperandsInSegments != operands.size())
+
+  if ((numOperandsInSegments != operands.size()) ||
+      (!deviceTypes && !operands.empty()))
     return op.emitOpError()
            << keyword << " operand count does not match count in segments";
-  if (deviceTypes.getValue().size() != (size_t)segments.size())
+  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
     return op.emitOpError()
            << keyword << " segment count does not match device_type count";
   return success();
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ec5430420524ce..96edb585ae21a2 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -507,6 +507,13 @@ acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64va
 
 // -----
 
+%0 = "arith.constant"() <{value = 1 : i64}> : () -> i64
+// expected-error@+1 {{num_gangs operand count does not match count in segments}}
+"acc.parallel"(%0) <{numGangsSegments = array<i32: 1>, operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>}> ({
+}) : (i64) -> ()
+
+// -----
+
 %i64value = arith.constant 1 : i64
 acc.parallel {
 // expected-error@+1 {{'acc.set' op cannot be nested in a compute operation}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/111586

