[flang-commits] [flang] c067c6e - [mlir][openacc] Use new private representation in acc.parallel
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Mon May 22 09:49:37 PDT 2023
Author: Valentin Clement
Date: 2023-05-22T09:49:31-07:00
New Revision: c067c6e55dcf6cd936d4b74aed254abaa6f398fc
URL: https://github.com/llvm/llvm-project/commit/c067c6e55dcf6cd936d4b74aed254abaa6f398fc
DIFF: https://github.com/llvm/llvm-project/commit/c067c6e55dcf6cd936d4b74aed254abaa6f398fc.diff
LOG: [mlir][openacc] Use new private representation in acc.parallel
Update acc.parallel private operands list to use the new design
introduced in D150622.
Tests in flang/test/Lower/OpenACC/acc-parallel.f90 and
flang/test/Lower/OpenACC/acc-parallel-loop.f90 are temporarily
disabled and will be re-enabled with the updated lowering in the
follow-up patch.
Reviewed By: razvanlupusoru
Differential Revision: https://reviews.llvm.org/D150971
Added:
Modified:
flang/test/Lower/OpenACC/acc-parallel-loop.f90
flang/test/Lower/OpenACC/acc-parallel.f90
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir
Removed:
################################################################################
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index fa7323ddf6b08..1b91b7c27a7f0 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -442,18 +442,19 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
- !$acc parallel loop private(a) firstprivate(b)
- DO i = 1, n
- a(i) = b(i)
- END DO
-
-! CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
-! CHECK: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
-! CHECK: fir.do_loop
-! CHECK: acc.yield
-! CHECK-NEXT: }{{$}}
-! CHECK: acc.yield
-! CHECK-NEXT: }{{$}}
+! TODO: will be updated after lowering change in privatization to MLIR
+! !$acc parallel loop private(a) firstprivate(b)
+! DO i = 1, n
+! a(i) = b(i)
+! END DO
+
+! TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
+! TODO: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
+! TODO: fir.do_loop
+! TODO: acc.yield
+! TODO-NEXT: }{{$}}
+! TODO: acc.yield
+! TODO-NEXT: }{{$}}
!$acc parallel loop seq
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a880f151d7eb9..93caa367b578b 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -288,11 +288,12 @@ subroutine acc_parallel
!CHECK: acc.detach accPtr(%[[ATTACH_D]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "d"}
!CHECK: acc.detach accPtr(%[[ATTACH_E]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "e"}
- !$acc parallel private(a) firstprivate(b) private(c)
- !$acc end parallel
+! TODO: will be updated after lowering change in privatization to MLIR
+! !$acc parallel private(a) firstprivate(b) private(c)
+! !$acc end parallel
-!CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
-!CHECK: acc.yield
-!CHECK-NEXT: }{{$}}
+!TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
+!TODO: acc.yield
+!TODO-NEXT: }{{$}}
end subroutine acc_parallel
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 83f1eba0ef72c..4cc91177e5a0b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -636,7 +636,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
UnitAttr:$selfAttr,
OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands,
- Variadic<AnyType>:$gangPrivateOperands,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$gangFirstPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
@@ -659,7 +660,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
type($gangFirstPrivateOperands) `)`
| `num_gangs` `(` $numGangs `:` type($numGangs) `)`
| `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
- | `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)`
+ | `private` `(` custom<PrivatizationList>(
+ $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
+ `)`
| `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 430582d38bd26..89f97e08d3ba4 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -436,6 +436,43 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Custom parser and printer verifier for private clause
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePrivatizationList(
+ mlir::OpAsmParser &parser,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+ llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
+ llvm::SmallVector<SymbolRefAttr> privatizationVec;
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(privatizationVec.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseOperand(operands.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+ llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
+ privatizationVec.end());
+ privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
+ return success();
+}
+
+static void
+printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::OperandRange privateOperands,
+ mlir::TypeRange privateTypes,
+ std::optional<mlir::ArrayAttr> privatizations) {
+ for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
+ if (i != 0)
+ p << ", ";
+ p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
+ << privateOperands[i].getType();
+ }
+}
+
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -455,6 +492,45 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
+static LogicalResult
+checkPrivatizationList(Operation *op,
+ std::optional<mlir::ArrayAttr> privatizations,
+ mlir::OperandRange privateOperands) {
+ if (!privateOperands.empty()) {
+ if (!privatizations || privatizations->size() != privateOperands.size())
+ return op->emitOpError() << "expected as many privatizations symbol "
+ "reference as private operands";
+ } else {
+ if (privatizations)
+ return op->emitOpError() << "unexpected privatizations symbol reference";
+ return success();
+ }
+
+ llvm::DenseSet<Value> privates;
+ for (auto args : llvm::zip(privateOperands, *privatizations)) {
+ mlir::Value privateOperand = std::get<0>(args);
+
+ if (!privates.insert(privateOperand).second)
+ return op->emitOpError() << "private operand appears more than once";
+
+ mlir::Type varType = privateOperand.getType();
+ auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+ auto decl =
+ SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
+ if (!decl)
+ return op->emitOpError() << "expected symbol reference " << symbolRef
+ << " to point to a private declaration";
+
+ if (decl.getType() && decl.getType() != varType)
+ return op->emitOpError()
+ << "expected private (" << varType
+ << ") to be the same type as private declaration ("
+ << decl.getType() << ")";
+ }
+
+ return success();
+}
+
unsigned ParallelOp::getNumDataOperands() {
return getReductionOperands().size() + getGangPrivateOperands().size() +
getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
@@ -471,6 +547,9 @@ Value ParallelOp::getDataOperand(unsigned i) {
}
LogicalResult acc::ParallelOp::verify() {
+ if (failed(checkPrivatizationList(*this, getPrivatizations(),
+ getGangPrivateOperands())))
+ return failure();
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index be1973f1f1963..c0498f99119fc 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -114,6 +114,16 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
// -----
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
@@ -126,7 +136,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32>
%pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
- acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
+ acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %c : memref<10xf32>) {
acc.loop gang {
scf.for %x = %lb to %c10 step %st {
acc.loop worker {
@@ -168,7 +178,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
// CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
-// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private([[ARG2]] : memref<10xf32>) {
+// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> [[ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.loop gang {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: acc.loop worker {
@@ -358,6 +368,26 @@ func.func @acc_loop_multiple_block() {
// -----
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
+^bb0(%arg0: memref<10x10xf32>):
+ %0 = memref.alloc() : memref<10x10xf32>
+ acc.yield %0 : memref<10x10xf32>
+} destroy {
+^bb0(%arg0: memref<10x10xf32>):
+ memref.dealloc %arg0 : memref<10x10xf32>
+ acc.terminator
+}
+
func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
%i64value = arith.constant 1 : i64
%i32value = arith.constant 1 : i32
@@ -394,7 +424,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel vector_length(%idxValue: index) {
}
- acc.parallel private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
+ acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
}
acc.parallel {
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -445,7 +475,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel vector_length([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.parallel {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
More information about the flang-commits
mailing list