[Mlir-commits] [mlir] 7c4e45e - [mlir][SCFToOpenMP] Add pass option to emit LLVM opaque pointers
Markus Böck
llvmlistbot at llvm.org
Mon Feb 13 02:50:02 PST 2023
Author: Markus Böck
Date: 2023-02-13T11:49:37+01:00
New Revision: 7c4e45ec7d6daf936c9a01f1ef2d107e9977f681
URL: https://github.com/llvm/llvm-project/commit/7c4e45ec7d6daf936c9a01f1ef2d107e9977f681
DIFF: https://github.com/llvm/llvm-project/commit/7c4e45ec7d6daf936c9a01f1ef2d107e9977f681.diff
LOG: [mlir][SCFToOpenMP] Add pass option to emit LLVM opaque pointers
Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179
Luckily, only a few changes had to be made. To also allow users to specify the pass option from C++ code, I migrated the pass to autogenerated constructors, which also autogenerates a pass option struct.
Differential Revision: https://reviews.llvm.org/D143855
Added:
mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/test/Conversion/SCFToOpenMP/reductions.mlir
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2924cc88044df..0533373b25e85 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -746,10 +746,16 @@ def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
// SCFToOpenMP
//===----------------------------------------------------------------------===//
-def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
+def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
- let constructor = "mlir::createConvertSCFToOpenMPPass()";
+
+ let options = [
+ Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+ /*default=*/"false", "Generate LLVM IR using opaque pointers "
+ "instead of typed pointers">
+ ];
+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
"memref::MemRefDialect"];
}
diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
index 7dd5315eedb4a..dfff8e66b066a 100644
--- a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -12,15 +12,11 @@
#include <memory>
namespace mlir {
-class ModuleOp;
-template <typename T>
-class OperationPass;
+class Pass;
-#define GEN_PASS_DECL_CONVERTSCFTOOPENMP
+#define GEN_PASS_DECL_CONVERTSCFTOOPENMPPASS
#include "mlir/Conversion/Passes.h.inc"
-std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToOpenMPPass();
-
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index a4acf0431b72f..78e63a595e800 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -26,7 +26,7 @@
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTSCFTOOPENMP
+#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
@@ -212,22 +212,32 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
return decl;
}
+/// Returns the opaque `!llvm.ptr` type if 'useOpaquePointers' is true (only
+/// the context of 'elementType' is used), else `!llvm.ptr<elementType>`.
+static LLVM::LLVMPointerType getPointerType(Type elementType,
+ bool useOpaquePointers) {
+ if (useOpaquePointers)
+ return LLVM::LLVMPointerType::get(elementType.getContext());
+ return LLVM::LLVMPointerType::get(elementType);
+}
+
/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
/// using llvm.atomicrmw of the given kind.
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
LLVM::AtomicBinOp atomicKind,
omp::ReductionDeclareOp decl,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ bool useOpaquePointers) {
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperand().getType();
- Type ptrType = LLVM::LLVMPointerType::get(type);
+ Type ptrType = getPointerType(type, useOpaquePointers);
Location reduceOperandLoc = reduce.getOperand().getLoc();
builder.createBlock(&decl.getAtomicReductionRegion(),
decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
{reduceOperandLoc, reduceOperandLoc});
Block *atomicBlock = &decl.getAtomicReductionRegion().back();
builder.setInsertionPointToEnd(atomicBlock);
- Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
+ Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
atomicBlock->getArgument(1));
builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
atomicBlock->getArgument(0), loaded,
@@ -241,7 +251,8 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
/// the neutral value, necessary for the OpenMP declaration. If the reduction
/// cannot be recognized, returns null.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ bool useOpaquePointers) {
Operation *container = SymbolTable::getNearestSymbolTable(reduce);
SymbolTable symbolTable(container);
@@ -262,29 +273,34 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(
builder, symbolTable, reduce,
builder.getIntegerAttr(
type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
+ useOpaquePointers);
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -316,7 +332,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
return addAtomicRMW(builder,
isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce);
+ decl, reduce, useOpaquePointers);
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -328,7 +344,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
return addAtomicRMW(
builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce);
+ decl, reduce, useOpaquePointers);
}
return nullptr;
@@ -337,7 +353,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
- using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+ bool useOpaquePointers;
+
+ ParallelOpLowering(MLIRContext *context, bool useOpaquePointers)
+ : OpRewritePattern<scf::ParallelOp>(context),
+ useOpaquePointers(useOpaquePointers) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -346,7 +367,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// declaration and use it instead of redeclaring.
SmallVector<Attribute> reductionDeclSymbols;
for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
- omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+ omp::ReductionDeclareOp decl =
+ declareReduction(rewriter, reduce, useOpaquePointers);
if (!decl)
return failure();
reductionDeclSymbols.push_back(
@@ -366,7 +388,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
"cannot create a reduction variable if the type is not an LLVM "
"pointer element");
Value storage = rewriter.create<LLVM::AllocaOp>(
- loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
+ loc, getPointerType(init.getType(), useOpaquePointers),
+ init.getType(), one, 0);
rewriter.create<LLVM::StoreOp>(loc, init, storage);
reductionVariables.push_back(storage);
}
@@ -426,8 +449,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Load loop results.
SmallVector<Value> results;
results.reserve(reductionVariables.size());
- for (Value variable : reductionVariables) {
- Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
+ for (auto [variable, type] :
+ llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
+ Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
results.push_back(res);
}
rewriter.replaceOp(parallelOp, results);
@@ -437,29 +461,29 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
/// A pass converting SCF operations to OpenMP operations.
-struct SCFToOpenMPPass : public impl::ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+struct SCFToOpenMPPass
+ : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
+
+ using Base::Base;
+
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), useOpaquePointers)))
signalPassFailure();
}
};
} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
- return std::make_unique<SCFToOpenMPPass>();
-}
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index d71f7578804d9..4cf5f1b0f753c 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' -split-input-file %s | FileCheck %s
// CHECK: omp.reduction.declare @[[$REDF:.*]] : f32
@@ -12,8 +12,8 @@
// CHECK: omp.yield(%[[RES]] : f32)
// CHECK: atomic
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<f32>, %[[ARG1:.*]]: !llvm.ptr<f32>):
-// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> f32
// CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic
// CHECK-LABEL: @reduction1
@@ -143,8 +143,8 @@ func.func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.yield(%[[RES]] : i64)
// CHECK: atomic
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
-// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> i64
// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
// CHECK-LABEL: @reduction4
@@ -187,8 +187,8 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.yield
}
// CHECK: omp.terminator
- // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]]
- // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]]
+ // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
+ // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
// CHECK: return %[[RES1]], %[[RES2]]
return %res#0, %res#1 : f32, i64
}
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b..508052d483ec1 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' %s | FileCheck %s
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
diff --git a/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir b/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir
new file mode 100644
index 0000000000000..fb90c5d7d10fd
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=0' -split-input-file %s | FileCheck %s
+
+// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG0]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction4
+func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) -> (f32, i64) {
+ %step = arith.constant 1 : index
+ // CHECK: %[[ZERO:.*]] = arith.constant 0.0
+ %zero = arith.constant 0.0 : f32
+ // CHECK: %[[IONE:.*]] = arith.constant 1
+ %ione = arith.constant 1 : i64
+ // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32
+ // CHECK: llvm.store %[[ZERO]], %[[BUF1]]
+ // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64
+ // CHECK: llvm.store %[[IONE]], %[[BUF2]]
+
+ // CHECK: omp.parallel
+ // CHECK: omp.wsloop
+ // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
+ // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]]
+ // CHECK: memref.alloca_scope
+ %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
+ %one = arith.constant 1.0 : f32
+ // CHECK: omp.reduction %{{.*}}, %[[BUF1]]
+ scf.reduce(%one) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ %res = arith.select %cmp, %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ // CHECK: arith.fptosi
+ %1 = arith.fptosi %one : f32 to i64
+ // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+ scf.reduce(%1) : i64 {
+ ^bb1(%lhs: i64, %rhs: i64):
+ %cmp = arith.cmpi slt, %lhs, %rhs : i64
+ %res = arith.select %cmp, %rhs, %lhs : i64
+ scf.reduce.return %res : i64
+ }
+ // CHECK: omp.yield
+ }
+ // CHECK: omp.terminator
+ // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr<f32>
+ // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr<i64>
+ // CHECK: return %[[RES1]], %[[RES2]]
+ return %res#0, %res#1 : f32, i64
+}
More information about the Mlir-commits mailing list