[Mlir-commits] [mlir] b79751e - [MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.

Alexander Belyaev llvmlistbot at llvm.org
Tue May 5 01:33:24 PDT 2020


Author: Alexander Belyaev
Date: 2020-05-05T10:32:13+02:00
New Revision: b79751e83d7e50aa897049e9831dff840926d368

URL: https://github.com/llvm/llvm-project/commit/b79751e83d7e50aa897049e9831dff840926d368
DIFF: https://github.com/llvm/llvm-project/commit/b79751e83d7e50aa897049e9831dff840926d368.diff

LOG: [MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.

Adding this pattern reduces code duplication. There is no need to have a
custom implementation for lowering to llvm.cmpxchg.

Differential Revision: https://reviews.llvm.org/D78753

Added: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
    mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
    mlir/test/Dialect/Standard/expand-atomic.mlir

Modified: 
    mlir/docs/Passes.md
    mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index b1f1a2eb2e17..0f48396b220e 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -39,3 +39,7 @@ This document describes the available MLIR passes and their contracts.
 ## `spv` Dialect Passes
 
 [include "SPIRVPasses.md"]
+
+## `standard` Dialect Passes
+
+[include "StandardPasses.md"]

diff  --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
index f33061b2d87c..9f57627c321f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..413c6523a756
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls)
+add_public_tablegen_target(MLIRStandardTransformsIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc StandardPasses ./)

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
new file mode 100644
index 000000000000..c0622e529564
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -0,0 +1,29 @@
+
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors in the
+// StandardOps transformation library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates a pass rewriting "maxf"/"minf" AtomicRMWOps as GenericAtomicRMWOps.
+std::unique_ptr<Pass> createExpandAtomicPass();
+
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
new file mode 100644
index 000000000000..b65c03d33fc1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -0,0 +1,19 @@
+//===-- Passes.td - StandardOps pass definition file -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
+#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ExpandAtomic : FunctionPass<"expand-atomic"> {
+  let summary = "Expands maxf/minf AtomicRMWOps into GenericAtomicRMWOps.";
+  let constructor = "mlir::createExpandAtomicPass()";
+}
+
+#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 3647a31fd950..1acfc9905384 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/LoopOps/Passes.h"
 #include "mlir/Dialect/Quant/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/ViewOpGraph.h"
@@ -86,6 +87,10 @@ inline void registerAllPasses() {
   // SPIR-V
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/SPIRV/Passes.h.inc"
+
+  // Standard
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
 }
 
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index d6c0cde2b86a..1a01daa1188e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2743,113 +2743,6 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
   }
 };
 
-/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
-/// retried until it succeeds in atomically storing a new value into memory.
-///
-///      +---------------------------------+
-///      |   <code before the AtomicRMWOp> |
-///      |   <compute initial %loaded>     |
-///      |   br loop(%loaded)              |
-///      +---------------------------------+
-///             |
-///  -------|   |
-///  |      v   v
-///  |   +--------------------------------+
-///  |   | loop(%loaded):                 |
-///  |   |   <body contents>              |
-///  |   |   %pair = cmpxchg              |
-///  |   |   %ok = %pair[0]               |
-///  |   |   %new = %pair[1]              |
-///  |   |   cond_br %ok, end, loop(%new) |
-///  |   +--------------------------------+
-///  |          |        |
-///  |-----------        |
-///                      v
-///      +--------------------------------+
-///      | end:                           |
-///      |   <code after the AtomicRMWOp> |
-///      +--------------------------------+
-///
-struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
-  using Base::Base;
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto atomicOp = cast<AtomicRMWOp>(op);
-    auto maybeKind = matchSimpleAtomicOp(atomicOp);
-    if (maybeKind)
-      return failure();
-
-    LLVM::FCmpPredicate predicate;
-    switch (atomicOp.kind()) {
-    case AtomicRMWKind::maxf:
-      predicate = LLVM::FCmpPredicate::ogt;
-      break;
-    case AtomicRMWKind::minf:
-      predicate = LLVM::FCmpPredicate::olt;
-      break;
-    default:
-      return failure();
-    }
-
-    OperandAdaptor<AtomicRMWOp> adaptor(operands);
-    auto loc = op->getLoc();
-    auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
-
-    // Split the block into initial, loop, and ending parts.
-    auto *initBlock = rewriter.getInsertionBlock();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
-    auto loopArgument = loopBlock->addArgument(valueType);
-    auto loopPosition = rewriter.getInsertionPoint();
-    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
-
-    // Compute the loaded value and branch to the loop block.
-    rewriter.setInsertionPointToEnd(initBlock);
-    auto memRefType = atomicOp.getMemRefType();
-    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
-    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
-    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
-
-    // Prepare the body of the loop block.
-    rewriter.setInsertionPointToStart(loopBlock);
-    auto predicateI64 =
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
-    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
-    auto lhs = loopArgument;
-    auto rhs = adaptor.value();
-    auto cmp =
-        rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
-    auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
-
-    // Prepare the epilog of the loop block.
-    rewriter.setInsertionPointToEnd(loopBlock);
-    // Append the cmpxchg op to the end of the loop block.
-    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
-    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
-    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
-    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
-        loc, pairType, dataPtr, loopArgument, select, successOrdering,
-        failureOrdering);
-    // Extract the %new_loaded and %ok values from the pair.
-    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
-        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
-    Value ok = rewriter.create<LLVM::ExtractValueOp>(
-        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
-
-    // Conditionally branch to the end or back to the loop depending on %ok.
-    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
-                                    loopBlock, newLoaded);
-
-    // The 'result' of the atomic_rmw op is the newly loaded value.
-    rewriter.replaceOp(op, {newLoaded});
-
-    return success();
-  }
-};
-
 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
 /// retried until it succeeds in atomically storing a new value into memory.
 ///
@@ -2985,7 +2878,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       AddIOpLowering,
       AllocaOpLowering,
       AndOpLowering,
-      AtomicCmpXchgOpLowering,
       AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,

diff  --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 6d83cbb375ae..a2f496c7ab93 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -17,3 +17,5 @@ add_mlir_dialect_library(MLIRStandardOps
   MLIRSideEffects
   MLIRViewLikeInterface
   )
+
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..b79a0b569b97
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRStandardOpsTransforms
+  ExpandAtomic.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
+
+  DEPENDS
+  MLIRStandardTransformsIncGen
+  )
+target_link_libraries(MLIRStandardOpsTransforms
+  PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRStandardOps
+  MLIRSupport
+  LLVMSupport
+  )

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
new file mode 100644
index 000000000000..41e0ffb60bc9
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
@@ -0,0 +1,93 @@
+//===- ExpandAtomic.cpp - Expand AtomicRMWOp to GenericAtomicRMWOp --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites an `atomic_rmw` that AtomicRMWOpLowering cannot map to a simple
+/// atomic op (the "maxf"/"minf" kinds) into a `generic_atomic_rmw` whose body
+/// does a cmpf+select ("minf" uses "olt"); other kinds fail to match.
+///
+/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+///
+/// is rewritten to
+///
+/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+/// ^bb0(%current: f32):
+///   %cmp = cmpf "ogt", %current, %fval : f32
+///   %new_value = select %cmp, %current, %fval : f32
+///   atomic_yield %new_value : f32
+/// }
+struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtomicRMWOp op,
+                                PatternRewriter &rewriter) const final {
+    CmpFPredicate predicate;
+    switch (op.kind()) {
+    case AtomicRMWKind::maxf:
+      predicate = CmpFPredicate::OGT;
+      break;
+    case AtomicRMWKind::minf:
+      predicate = CmpFPredicate::OLT;
+      break;
+    default:
+      return failure();
+    }
+
+    auto loc = op.getLoc();
+    auto genericOp =
+        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+    // Build the body via `rewriter` so the driver is notified of these ops.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(genericOp.getBody());
+    Value lhs = genericOp.getCurrentValue();
+    Value rhs = op.value();
+    Value cmp = rewriter.create<CmpFOp>(loc, predicate, lhs, rhs);
+    Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+    rewriter.create<AtomicYieldOp>(loc, select);
+
+    rewriter.replaceOp(op, genericOp.getResult());
+    return success();
+  }
+};
+
+struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> {
+  void runOnFunction() override {
+    ConversionTarget target(getContext());
+    target.addLegalOp<GenericAtomicRMWOp>();
+    // Only "maxf"/"minf" need expanding; other kinds stay legal as-is.
+    target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
+      AtomicRMWKind kind = op.kind();
+      return !(kind == AtomicRMWKind::maxf || kind == AtomicRMWKind::minf);
+    });
+    OwningRewritePatternList patterns;
+    patterns.insert<AtomicRMWOpConverter>(&getContext());
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createExpandAtomicPass() {
+  return std::make_unique<ExpandAtomic>(); // A new, independent instance.
+}

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
new file mode 100644
index 000000000000..4748bf83ab99
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
@@ -0,0 +1,23 @@
+//===- PassDetail.h - StandardOps Pass class details ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class AtomicRMWOp; // NOTE(review): looks unused by the generated code; verify.
+
+// Pass base classes (e.g. ExpandAtomicBase) generated from Passes.td.
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
+} // end namespace mlir
+
+#endif // DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 1b17e46ccc1b..b7cb13e51ca2 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1110,25 +1110,6 @@ func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval :
 
 // -----
 
-// CHECK-LABEL: func @cmpxchg
-func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
-  %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
-  // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
-  // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
-  // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
-  // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
-  // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
-  // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
-  // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
-  // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
-  // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
-  // CHECK-NEXT: ^bb2:
-  return %x : f32
-  // CHECK-NEXT: llvm.return %[[new]]
-}
-
-// -----
-
 // CHECK-LABEL: func @generic_atomic_rmw
 func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
   %x = generic_atomic_rmw %I[%i] : memref<10xf32> {

diff  --git a/mlir/test/Dialect/Standard/expand-atomic.mlir b/mlir/test/Dialect/Standard/expand-atomic.mlir
new file mode 100644
index 000000000000..b4e65945f58a
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-atomic.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: func @atomic_rmw_to_generic
+// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
+func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
+  %x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  return %x : f32
+}
+// CHECK: [[RESULT:%.*]] = std.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32
+// CHECK:   [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
+// CHECK:   atomic_yield [[SELECT]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @atomic_rmw_no_conversion
+func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
+  %x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  return %x : f32
+}
+// CHECK-NOT: generic_atomic_rmw


        


More information about the Mlir-commits mailing list