[Mlir-commits] [mlir] a435e1f - [acc] Add attribute for combined constructs (#80319)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 7 10:06:52 PST 2024


Author: Razvan Lupusoru
Date: 2024-03-07T10:06:47-08:00
New Revision: a435e1f63bbd8c6d0ff140ccc890c25787091490

URL: https://github.com/llvm/llvm-project/commit/a435e1f63bbd8c6d0ff140ccc890c25787091490
DIFF: https://github.com/llvm/llvm-project/commit/a435e1f63bbd8c6d0ff140ccc890c25787091490.diff

LOG: [acc] Add attribute for combined constructs (#80319)

Combined constructs are decomposed into separate operations. However,
this decomposition alone does not meet the `acc` dialect's goal of being
able to regenerate clauses that are semantically equivalent to what the
user wrote. Thus, add an attribute to keep track of combined constructs.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACC.h
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/invalid.mlir
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index bb3b9617c24edb..0c8e0b45878206 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -133,6 +133,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
   return StringLiteral("acc.routine_info");
 }
 
+static constexpr StringLiteral getCombinedConstructsAttrName() { // Accessor for the combined-constructs attribute name.
+  return CombinedConstructsTypeAttr::name; // Forwards the attribute class's own `name` member rather than a literal.
+}
+
 struct RuntimeCounters
     : public mlir::SideEffects::Resource::Base<RuntimeCounters> {
   mlir::StringRef getName() final { return "AccRuntimeCounters"; }

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 6da7a742bbed8c..b5ad46361fa698 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
   let constBuilderCall = ?;
 }
 
+// Combined constructs enumerations: one case per construct that can be combined with `loop`.
+def OpenACC_KernelsLoop    : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">; // value 1, spelled `kernels_loop`
+def OpenACC_ParallelLoop   : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">; // value 2, spelled `parallel_loop`
+def OpenACC_SerialLoop     : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">; // value 3, spelled `serial_loop`
+
+def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType", // 32-bit enum grouping the three cases above.
+    "Differentiate between combined constructs",
+    [OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
+  let genSpecializedAttr = 0; // NOTE(review): presumably so the attribute class comes from the EnumAttr def below — confirm.
+  let cppNamespace = "::mlir::acc";
+}
+
+def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect, // Dialect attribute wrapping the enum, mnemonic `combined_constructs`.
+                                       OpenACC_CombinedConstructsType,
+                                       "combined_constructs"> {
+  let assemblyFormat = [{ ```<` $value `>` }];
+}
+
 // Define a resource for the OpenACC runtime counters.
 def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
 
@@ -933,7 +951,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -993,6 +1012,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1068,7 +1088,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1109,6 +1130,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1182,7 +1204,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
       Optional<I1>:$selfCond,
       UnitAttr:$selfAttr,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1242,6 +1265,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1573,7 +1597,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
       Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
       Variadic<AnyType>:$reductionOperands,
-      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
+      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
   );
 
   let results = (outs Variadic<AnyType>:$results);
@@ -1665,6 +1690,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
+    custom<CombinedConstructsLoop>($combined)
     oilist(
         `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 18187e7d4f66cd..c09a3403f9a3e3 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1283,6 +1283,50 @@ static void printDeviceTypeOperandsWithKeywordOnly(
   p << ")";
 }
 
+static ParseResult // Parses the optional `combined(kernels|parallel|serial)` clause into `attr`.
+parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
+                            mlir::acc::CombinedConstructsTypeAttr &attr) {
+  if (succeeded(parser.parseOptionalKeyword("combined"))) { // Clause is optional: without the keyword, `attr` is not assigned.
+    if (parser.parseLParen()) // Once `combined` is seen, the parenthesized name is mandatory.
+      return failure();
+    if (succeeded(parser.parseOptionalKeyword("kernels"))) { // Each accepted keyword maps to one enum case.
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
+    } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
+    } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
+    } else { // Any other token, e.g. `loop` or an immediate `)`, is rejected with a diagnostic.
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected compute construct name");
+      return failure();
+    }
+    if (parser.parseRParen()) // Exactly one name is accepted; `combined(parallel loop)` fails here.
+      return failure();
+  }
+  return success(); // Also reached when the clause is absent.
+}
+
+static void // Prints the `combined(...)` clause matching what parseCombinedConstructsLoop accepts.
+printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, // `op` is not referenced in the body.
+                            mlir::acc::CombinedConstructsTypeAttr attr) {
+  if (attr) { // A null attribute prints nothing, so the clause is omitted entirely.
+    switch (attr.getValue()) { // No default case: the three enum cases are listed explicitly.
+    case mlir::acc::CombinedConstructsType::KernelsLoop:
+      p << "combined(kernels)";
+      break;
+    case mlir::acc::CombinedConstructsType::ParallelLoop:
+      p << "combined(parallel)";
+      break;
+    case mlir::acc::CombinedConstructsType::SerialLoop:
+      p << "combined(serial)";
+      break;
+    }; // NOTE(review): the `;` after this brace is an empty statement and could be dropped.
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // SerialOp
 //===----------------------------------------------------------------------===//
@@ -1851,6 +1895,13 @@ LogicalResult acc::LoopOp::verify() {
           "reductions", false)))
     return failure();
 
+  if (getCombined().has_value() &&
+      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
+       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
+       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
+    return emitError("unexpected combined constructs attribute");
+  }
+
   // Check non-empty body().
   if (getRegion().empty())
     return emitError("expected non-empty body.");

diff  --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 70747b7e2acf4b..ec5430420524ce 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -738,3 +738,43 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
     acc.terminator
   }
 }
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected 'loop'}}
+  acc.parallel combined() {
+  }
+
+  return
+}
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected compute construct name}}
+  acc.loop combined(loop) {
+  }
+
+  return
+}
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected 'loop'}}
+  acc.parallel combined(parallel loop) {
+  }
+
+  return
+}
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected ')'}}
+  acc.loop combined(parallel loop) {
+  }
+
+  return
+}

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 1739b3de3e65fd..2ef2178cb2b63a 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
 
 // -----
 
-%c2 = arith.constant 2 : i32
-%c1 = arith.constant 1 : i32
-acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+// CHECK-LABEL: func.func @acc_num_gangs
+func.func @acc_num_gangs() {
+  %c2 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : i32
+  acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+  }
+
+  return
 }
 
 // CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+// -----
+
+// CHECK-LABEL: func.func @acc_combined
+func.func @acc_combined() {
+  acc.parallel combined(loop) {
+    acc.loop combined(parallel) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  acc.kernels combined(loop) {
+    acc.loop combined(kernels) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  acc.serial combined(loop) {
+    acc.loop combined(serial) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  return
+}
+
+// CHECK: acc.parallel combined(loop)
+// CHECK: acc.loop combined(parallel)
+// CHECK: acc.kernels combined(loop)
+// CHECK: acc.loop combined(kernels)
+// CHECK: acc.serial combined(loop)
+// CHECK: acc.loop combined(serial)


        


More information about the Mlir-commits mailing list