[Mlir-commits] [mlir] b79751e - [MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.
Alexander Belyaev
llvmlistbot at llvm.org
Tue May 5 01:33:24 PDT 2020
Author: Alexander Belyaev
Date: 2020-05-05T10:32:13+02:00
New Revision: b79751e83d7e50aa897049e9831dff840926d368
URL: https://github.com/llvm/llvm-project/commit/b79751e83d7e50aa897049e9831dff840926d368
DIFF: https://github.com/llvm/llvm-project/commit/b79751e83d7e50aa897049e9831dff840926d368.diff
LOG: [MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.
Adding this pattern reduces code duplication. There is no need to have a
custom implementation for lowering to llvm.cmpxchg.
Differential Revision: https://reviews.llvm.org/D78753
Added:
mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
mlir/test/Dialect/Standard/expand-atomic.mlir
Modified:
mlir/docs/Passes.md
mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index b1f1a2eb2e17..0f48396b220e 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -39,3 +39,7 @@ This document describes the available MLIR passes and their contracts.
## `spv` Dialect Passes
[include "SPIRVPasses.md"]
+
+## `standard` Dialect Passes
+
+[include "StandardPasses.md"]
diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
index f33061b2d87c..9f57627c321f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..413c6523a756
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls)
+add_public_tablegen_target(MLIRStandardTransformsIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc StandardPasses ./)
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
new file mode 100644
index 000000000000..c0622e529564
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -0,0 +1,29 @@
+
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors in the
+// Standard dialect transformation library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates an instance of the ExpandAtomic pass.
+std::unique_ptr<Pass> createExpandAtomicPass();
+
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
new file mode 100644
index 000000000000..b65c03d33fc1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -0,0 +1,19 @@
+//===-- Passes.td - StandardOps pass definition file -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
+#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ExpandAtomic : FunctionPass<"expand-atomic"> {
+ let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp.";
+ let constructor = "mlir::createExpandAtomicPass()";
+}
+
+#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 3647a31fd950..1acfc9905384 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -34,6 +34,7 @@
#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/ViewOpGraph.h"
@@ -86,6 +87,10 @@ inline void registerAllPasses() {
// SPIR-V
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/SPIRV/Passes.h.inc"
+
+ // Standard
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
}
} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index d6c0cde2b86a..1a01daa1188e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2743,113 +2743,6 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
}
};
-/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
-/// retried until it succeeds in atomically storing a new value into memory.
-///
-/// +---------------------------------+
-/// | <code before the AtomicRMWOp> |
-/// | <compute initial %loaded> |
-/// | br loop(%loaded) |
-/// +---------------------------------+
-/// |
-/// -------| |
-/// | v v
-/// | +--------------------------------+
-/// | | loop(%loaded): |
-/// | | <body contents> |
-/// | | %pair = cmpxchg |
-/// | | %ok = %pair[0] |
-/// | | %new = %pair[1] |
-/// | | cond_br %ok, end, loop(%new) |
-/// | +--------------------------------+
-/// | | |
-/// |----------- |
-/// v
-/// +--------------------------------+
-/// | end: |
-/// | <code after the AtomicRMWOp> |
-/// +--------------------------------+
-///
-struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
- using Base::Base;
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto atomicOp = cast<AtomicRMWOp>(op);
- auto maybeKind = matchSimpleAtomicOp(atomicOp);
- if (maybeKind)
- return failure();
-
- LLVM::FCmpPredicate predicate;
- switch (atomicOp.kind()) {
- case AtomicRMWKind::maxf:
- predicate = LLVM::FCmpPredicate::ogt;
- break;
- case AtomicRMWKind::minf:
- predicate = LLVM::FCmpPredicate::olt;
- break;
- default:
- return failure();
- }
-
- OperandAdaptor<AtomicRMWOp> adaptor(operands);
- auto loc = op->getLoc();
- auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
-
- // Split the block into initial, loop, and ending parts.
- auto *initBlock = rewriter.getInsertionBlock();
- auto initPosition = rewriter.getInsertionPoint();
- auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
- auto loopArgument = loopBlock->addArgument(valueType);
- auto loopPosition = rewriter.getInsertionPoint();
- auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
-
- // Compute the loaded value and branch to the loop block.
- rewriter.setInsertionPointToEnd(initBlock);
- auto memRefType = atomicOp.getMemRefType();
- auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
- Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
- rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
-
- // Prepare the body of the loop block.
- rewriter.setInsertionPointToStart(loopBlock);
- auto predicateI64 =
- rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
- auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
- auto lhs = loopArgument;
- auto rhs = adaptor.value();
- auto cmp =
- rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
- auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
-
- // Prepare the epilog of the loop block.
- rewriter.setInsertionPointToEnd(loopBlock);
- // Append the cmpxchg op to the end of the loop block.
- auto successOrdering = LLVM::AtomicOrdering::acq_rel;
- auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
- auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
- loc, pairType, dataPtr, loopArgument, select, successOrdering,
- failureOrdering);
- // Extract the %new_loaded and %ok values from the pair.
- Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
- loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
- Value ok = rewriter.create<LLVM::ExtractValueOp>(
- loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
-
- // Conditionally branch to the end or back to the loop depending on %ok.
- rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
- loopBlock, newLoaded);
-
- // The 'result' of the atomic_rmw op is the newly loaded value.
- rewriter.replaceOp(op, {newLoaded});
-
- return success();
- }
-};
-
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
@@ -2985,7 +2878,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
AddIOpLowering,
AllocaOpLowering,
AndOpLowering,
- AtomicCmpXchgOpLowering,
AtomicRMWOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 6d83cbb375ae..a2f496c7ab93 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -17,3 +17,5 @@ add_mlir_dialect_library(MLIRStandardOps
MLIRSideEffects
MLIRViewLikeInterface
)
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..b79a0b569b97
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRStandardOpsTransforms
+ ExpandAtomic.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
+
+ DEPENDS
+ MLIRStandardTransformsIncGen
+ )
+target_link_libraries(MLIRStandardOpsTransforms
+ PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIRStandardOps
+ MLIRSupport
+ LLVMSupport
+ )
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
new file mode 100644
index 000000000000..41e0ffb60bc9
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
@@ -0,0 +1,93 @@
+//===- ExpandAtomic.cpp - Expand atomic_rmw operations --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
+/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
+/// `generic_atomic_rmw` with the expanded code.
+///
+/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+///
+/// will be lowered to
+///
+/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+/// ^bb0(%current: f32):
+/// %cmp = cmpf "ogt", %current, %fval : f32
+/// %new_value = select %cmp, %current, %fval : f32
+/// atomic_yield %new_value : f32
+/// }
+struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AtomicRMWOp op,
+ PatternRewriter &rewriter) const final {
+ CmpFPredicate predicate;
+ switch (op.kind()) {
+ case AtomicRMWKind::maxf:
+ predicate = CmpFPredicate::OGT;
+ break;
+ case AtomicRMWKind::minf:
+ predicate = CmpFPredicate::OLT;
+ break;
+ default:
+ return failure();
+ }
+
+ auto loc = op.getLoc();
+ auto genericOp =
+ rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+ OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody());
+
+ Value lhs = genericOp.getCurrentValue();
+ Value rhs = op.value();
+ Value cmp = bodyBuilder.create<CmpFOp>(loc, predicate, lhs, rhs);
+ Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
+ bodyBuilder.create<AtomicYieldOp>(loc, select);
+
+ rewriter.replaceOp(op, genericOp.getResult());
+ return success();
+ }
+};
+
+struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ patterns.insert<AtomicRMWOpConverter>(&getContext());
+
+ ConversionTarget target(getContext());
+ target.addLegalOp<GenericAtomicRMWOp>();
+ target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
+ return op.kind() != AtomicRMWKind::maxf &&
+ op.kind() != AtomicRMWKind::minf;
+ });
+ if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createExpandAtomicPass() {
+ return std::make_unique<ExpandAtomic>();
+}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
new file mode 100644
index 000000000000..4748bf83ab99
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
@@ -0,0 +1,23 @@
+//===- PassDetail.h - Standard Pass class details ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class AtomicRMWOp;
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
+
+} // end namespace mlir
+
+#endif // DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 1b17e46ccc1b..b7cb13e51ca2 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1110,25 +1110,6 @@ func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval :
// -----
-// CHECK-LABEL: func @cmpxchg
-func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
- %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
- // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
- // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
- // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
- // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
- // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
- // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
- // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
- // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
- // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
- // CHECK-NEXT: ^bb2:
- return %x : f32
- // CHECK-NEXT: llvm.return %[[new]]
-}
-
-// -----
-
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
diff --git a/mlir/test/Dialect/Standard/expand-atomic.mlir b/mlir/test/Dialect/Standard/expand-atomic.mlir
new file mode 100644
index 000000000000..b4e65945f58a
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-atomic.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: func @atomic_rmw_to_generic
+// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
+func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
+ %x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ return %x : f32
+}
+// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32
+// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
+// CHECK: atomic_yield [[SELECT]] : f32
+// CHECK: }
+// CHECK: return %0 : f32
+
+// -----
+
+// CHECK-LABEL: func @atomic_rmw_no_conversion
+func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
+ %x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ return %x : f32
+}
+// CHECK-NOT: generic_atomic_rmw
More information about the Mlir-commits
mailing list