[Mlir-commits] [mlir] fc421d7 - [MLIR] Remove all-reduce lowering from GPU to NVVM. Use in-dialect lowering instead.

Christian Sigg llvmlistbot at llvm.org
Wed Mar 11 07:18:05 PDT 2020


Author: Christian Sigg
Date: 2020-03-11T15:17:54+01:00
New Revision: fc421d7ca3ecc4864c547e587eac4dacd98c33da

URL: https://github.com/llvm/llvm-project/commit/fc421d7ca3ecc4864c547e587eac4dacd98c33da
DIFF: https://github.com/llvm/llvm-project/commit/fc421d7ca3ecc4864c547e587eac4dacd98c33da.diff

LOG: [MLIR] Remove all-reduce lowering from GPU to NVVM. Use in-dialect lowering instead.

Reviewers: herhut, mravishankar

Reviewed By: herhut

Subscribers: merge_guards_bot, jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73794

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 7af62dfb3e4a..67c335e629fe 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Pass/Pass.h"
@@ -28,448 +29,6 @@ using namespace mlir;
 
 namespace {
 
-/// Converts all_reduce op to LLVM/NVVM ops.
-struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
-  using AccumulatorFactory =
-      std::function<Value(Location, Value, Value, ConversionPatternRewriter &)>;
-
-  explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(gpu::AllReduceOp::getOperationName(),
-                             lowering_.getDialect()->getContext(), lowering_),
-        int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Value operand = operands.front();
-
-    // TODO(csigg): Generalize to other types of accumulation.
-    assert(op->getOperand(0).getType().isSignlessIntOrFloat());
-
-    // Create the reduction using an accumulator factory.
-    AccumulatorFactory factory =
-        getFactory(cast<gpu::AllReduceOp>(op), operand);
-    assert(factory && "failed to create accumulator factory");
-    Value result = createBlockReduce(loc, operand, factory, rewriter);
-
-    rewriter.replaceOp(op, {result});
-    return matchSuccess();
-  }
-
-private:
-  /// Returns an accumulator factory using either the op attribute or the body
-  /// region.
-  AccumulatorFactory getFactory(gpu::AllReduceOp allReduce,
-                                Value operand) const {
-    if (!allReduce.body().empty()) {
-      return getFactory(allReduce.body());
-    }
-    if (allReduce.op()) {
-      auto type = operand.getType().cast<LLVM::LLVMType>();
-      return getFactory(*allReduce.op(), type.getUnderlyingType());
-    }
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that clones the body. The body's entry
-  /// block is expected to have 2 arguments. The gpu.yield return the
-  /// accumulated value of the same type.
-  AccumulatorFactory getFactory(Region &body) const {
-    return AccumulatorFactory([&](Location loc, Value lhs, Value rhs,
-                                  ConversionPatternRewriter &rewriter) {
-      Block *block = rewriter.getInsertionBlock();
-      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
-
-      // Insert accumulator body between split block.
-      BlockAndValueMapping mapping;
-      mapping.map(body.front().getArgument(0), lhs);
-      mapping.map(body.front().getArgument(1), rhs);
-      rewriter.cloneRegionBefore(body, *split->getParent(),
-                                 split->getIterator(), mapping);
-
-      // Add branch before inserted body, into body.
-      block = block->getNextNode();
-      rewriter.create<LLVM::BrOp>(loc, ValueRange(), block);
-
-      // Replace all gpu.yield ops with branch out of body.
-      for (; block != split; block = block->getNextNode()) {
-        Operation *terminator = block->getTerminator();
-        if (!llvm::isa<gpu::YieldOp>(terminator))
-          continue;
-        rewriter.setInsertionPointToEnd(block);
-        rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-            terminator, terminator->getOperand(0), split);
-      }
-
-      // Return accumulator result.
-      rewriter.setInsertionPointToStart(split);
-      return split->addArgument(lhs.getType());
-    });
-  }
-
-  /// Returns an accumulator factory that creates an op specified by opName.
-  AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
-    if (type->isVectorTy() || type->isArrayTy())
-      return getFactory(opName, type->getSequentialElementType());
-
-    bool isFloatingPoint = type->isFloatingPointTy();
-
-    if (opName == "add") {
-      return isFloatingPoint ? getFactory<LLVM::FAddOp>()
-                             : getFactory<LLVM::AddOp>();
-    }
-    if (opName == "mul") {
-      return isFloatingPoint ? getFactory<LLVM::FMulOp>()
-                             : getFactory<LLVM::MulOp>();
-    }
-    if (opName == "and") {
-      return getFactory<LLVM::AndOp>();
-    }
-    if (opName == "or") {
-      return getFactory<LLVM::OrOp>();
-    }
-    if (opName == "xor") {
-      return getFactory<LLVM::XOrOp>();
-    }
-    if (opName == "max") {
-      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
-                                             LLVM::FCmpPredicate::ugt>()
-                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
-                                             LLVM::ICmpPredicate::ugt>();
-    }
-    if (opName == "min") {
-      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
-                                             LLVM::FCmpPredicate::ult>()
-                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
-                                             LLVM::ICmpPredicate::ult>();
-    }
-
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that creates an op of type T.
-  template <typename T>
-  AccumulatorFactory getFactory() const {
-    return [](Location loc, Value lhs, Value rhs,
-              ConversionPatternRewriter &rewriter) {
-      return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
-    };
-  }
-
-  /// Returns an accumulator for comparaison such as min, max. T is the type
-  /// of the compare op.
-  template <typename T, typename PredicateEnum, PredicateEnum predicate>
-  AccumulatorFactory getCmpFactory() const {
-    return [](Location loc, Value lhs, Value rhs,
-              ConversionPatternRewriter &rewriter) {
-      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
-      return rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
-    };
-  }
-
-  /// Creates an all_reduce across the block.
-  ///
-  /// First reduce the elements within a warp. The first thread of each warp
-  /// writes the intermediate result to shared memory. After synchronizing the
-  /// block, the first warp reduces the values from shared memory. The result
-  /// is broadcasted to all threads through shared memory.
-  ///
-  ///     %warp_reduce = `createWarpReduce(%operand)`
-  ///     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
-  ///     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
-  ///     %lane_id = nvvm.read.ptx.sreg.laneid  : !llvm.i32
-  ///     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
-  ///     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
-  ///     llvm.cond_br %is_first_lane, ^then1, ^continue1
-  ///   ^then1:
-  ///     %warp_id = `getWarpId()`
-  ///     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
-  ///     llvm.store %store_dst, %warp_reduce
-  ///     llvm.br ^continue1
-  ///   ^continue1:
-  ///     nvvm.barrier0
-  ///     %num_warps = `getNumWarps()` : !llvm.i32
-  ///     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
-  ///     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
-  ///     llvm.cond_br %is_first_lane, ^then2, ^continue2
-  ///   ^then2:
-  ///     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
-  ///     %value = llvm.load %load_src
-  ///     %result = `createWarpReduce(%value)`
-  ///     llvm.store %result_ptr, %result
-  ///     llvm.br ^continue2
-  ///   ^continue2:
-  ///     nvvm.barrier0
-  ///     %result = llvm.load %result_ptr
-  ///     return %result
-  ///
-  Value createBlockReduce(Location loc, Value operand,
-                          AccumulatorFactory &accumFactory,
-                          ConversionPatternRewriter &rewriter) const {
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    // Create shared memory array to store the warp reduction.
-    auto module = operand.getDefiningOp()->getParentOfType<gpu::GPUModuleOp>();
-    assert(module && "op must belong to a module");
-    Value sharedMemPtr =
-        createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
-
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(0u));
-    Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
-    Value isFirstLane = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::eq, laneId, zero);
-    Value threadIdx = getLinearThreadIndex(loc, rewriter);
-    Value blockSize = getBlockSize(loc, rewriter);
-    Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
-
-    // Reduce elements within each warp to produce the intermediate results.
-    Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
-                                        accumFactory, rewriter);
-
-    // Write the intermediate results to shared memory, using the first lane of
-    // each warp.
-    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
-      Value warpId = getDivideByWarpSize(threadIdx, rewriter);
-      Value storeDst = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId}));
-      rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    Value numWarps = getNumWarps(loc, blockSize, rewriter);
-    Value isValidWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
-    Value resultPtr = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero}));
-
-    // Use the first numWarps threads to reduce the intermediate results from
-    // shared memory. The final result is written to shared memory again.
-    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
-      Value loadSrc = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx}));
-      Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-      Value result = createWarpReduce(loc, numWarps, laneId, value,
-                                      accumFactory, rewriter);
-      rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    // Load and return result from shared memory.
-    Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
-    return result;
-  }
-
-  /// Creates an if-block skeleton and calls the two factories to generate the
-  /// ops in the `then` and `else` block..
-  ///
-  ///     llvm.cond_br %condition, ^then, ^continue
-  ///   ^then:
-  ///     %then_operands = `thenOpsFactory()`
-  ///     llvm.br ^continue(%then_operands)
-  ///   ^else:
-  ///     %else_operands = `elseOpsFactory()`
-  ///     llvm.br ^continue(%else_operands)
-  ///   ^continue(%block_operands):
-  ///
-  template <typename ThenOpsFactory, typename ElseOpsFactory>
-  void createIf(Location loc, ConversionPatternRewriter &rewriter,
-                Value condition, ThenOpsFactory &&thenOpsFactory,
-                ElseOpsFactory &&elseOpsFactory) const {
-    Block *currentBlock = rewriter.getInsertionBlock();
-    auto currentPoint = rewriter.getInsertionPoint();
-
-    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
-    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
-    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
-
-    rewriter.setInsertionPointToEnd(currentBlock);
-    rewriter.create<LLVM::CondBrOp>(loc, condition, thenBlock, elseBlock);
-
-    auto addBranch = [&](ValueRange operands) {
-      rewriter.create<LLVM::BrOp>(loc, operands, continueBlock);
-    };
-
-    rewriter.setInsertionPointToStart(thenBlock);
-    auto thenOperands = thenOpsFactory();
-    addBranch(thenOperands);
-
-    rewriter.setInsertionPointToStart(elseBlock);
-    auto elseOperands = elseOpsFactory();
-    addBranch(elseOperands);
-
-    assert(thenOperands.size() == elseOperands.size());
-    rewriter.setInsertionPointToStart(continueBlock);
-    for (auto operand : thenOperands)
-      continueBlock->addArgument(operand.getType());
-  }
-
-  /// Shortcut for createIf with empty else block and no block operands.
-  template <typename Factory>
-  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
-                             Value condition,
-                             Factory &&predicatedOpsFactory) const {
-    createIf(
-        loc, rewriter, condition,
-        [&] {
-          predicatedOpsFactory();
-          return ArrayRef<Value>();
-        },
-        [&] { return ArrayRef<Value>(); });
-  }
-
-  /// Creates a reduction across the first activeWidth lanes of a warp.
-  /// The first lane returns the result, all others return values are undefined.
-  Value createWarpReduce(Location loc, Value activeWidth, Value laneId,
-                         Value operand, AccumulatorFactory accumFactory,
-                         ConversionPatternRewriter &rewriter) const {
-    Value warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    createIf(
-        loc, rewriter, isPartialWarp,
-        // Generate reduction over a (potentially) partial warp.
-        [&] {
-          Value value = operand;
-          Value one = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(1));
-          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
-          Value activeMask = rewriter.create<LLVM::SubOp>(
-              loc, int32Type,
-              rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
-              one);
-          // Clamp lane: `activeWidth - 1`
-          Value maskAndClamp =
-              rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
-          auto dialect = typeConverter.getDialect();
-          auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
-          auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
-          auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
-
-          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
-          // lane is within the active range. All lanes contain the final
-          // result, but only the first lane's result is used.
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shfl = rewriter.create<NVVM::ShflBflyOp>(
-                loc, shflTy, activeMask, value, offset, maskAndClamp,
-                returnValueAndIsValidAttr);
-            Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
-                loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
-            // Skip the accumulation if the shuffle op read from a lane outside
-            // of the active range.
-            createIf(
-                loc, rewriter, isActiveSrcLane,
-                [&] {
-                  Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
-                      loc, type, shfl, rewriter.getIndexArrayAttr(0));
-                  return SmallVector<Value, 1>{
-                      accumFactory(loc, value, shflValue, rewriter)};
-                },
-                [&] { return llvm::makeArrayRef(value); });
-            value = rewriter.getInsertionBlock()->getArgument(0);
-          }
-          return SmallVector<Value, 1>{value};
-        },
-        // Generate a reduction over the entire warp. This is a specialization
-        // of the above reduction with unconditional accumulation.
-        [&] {
-          Value value = operand;
-          Value activeMask = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
-          Value maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shflValue = rewriter.create<NVVM::ShflBflyOp>(
-                loc, type, activeMask, value, offset, maskAndClamp,
-                /*return_value_and_is_valid=*/UnitAttr());
-            value = accumFactory(loc, value, shflValue, rewriter);
-          }
-          return SmallVector<Value, 1>{value};
-        });
-    return rewriter.getInsertionBlock()->getArgument(0);
-  }
-
-  /// Creates a global array stored in shared memory.
-  Value createSharedMemoryArray(Location loc, gpu::GPUModuleOp module,
-                                LLVM::LLVMType elementType, int numElements,
-                                ConversionPatternRewriter &rewriter) const {
-    OpBuilder builder(module.body());
-
-    auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
-    StringRef name = "reduce_buffer";
-    auto globalOp = builder.create<LLVM::GlobalOp>(
-        loc, arrayType.cast<LLVM::LLVMType>(),
-        /*isConstant=*/false, LLVM::Linkage::Internal, name,
-        /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace());
-
-    return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
-  }
-
-  /// Returns the index of the thread within the block.
-  Value getLinearThreadIndex(Location loc,
-                             ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
-    Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
-    Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
-    Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
-    Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
-    Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
-    return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
-  }
-
-  /// Returns the number of threads in the block.
-  Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
-    Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
-    return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
-  }
-
-  /// Returns the number of warps in the block.
-  Value getNumWarps(Location loc, Value blockSize,
-                    ConversionPatternRewriter &rewriter) const {
-    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
-        loc, int32Type, blockSize, warpSizeMinusOne);
-    return getDivideByWarpSize(biasedBlockSize, rewriter);
-  }
-
-  /// Returns the number of active threads in the warp, not clamped to 32.
-  Value getActiveWidth(Location loc, Value threadIdx, Value blockSize,
-                       ConversionPatternRewriter &rewriter) const {
-    Value threadIdxMask = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
-    Value numThreadsWithSmallerWarpId =
-        rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
-    return rewriter.create<LLVM::SubOp>(loc, blockSize,
-                                        numThreadsWithSmallerWarpId);
-  }
-
-  /// Returns value divided by the warp size (i.e. 32).
-  Value getDivideByWarpSize(Value value,
-                            ConversionPatternRewriter &rewriter) const {
-    auto loc = value.getLoc();
-    auto warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
-  }
-
-  LLVM::LLVMType int32Type;
-
-  static constexpr int kWarpSize = 32;
-};
 
 struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
   explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
@@ -704,6 +263,14 @@ class LowerGpuOpsToNVVMOpsPass
     });
 
     OwningRewritePatternList patterns;
+
+    // Apply in-dialect lowering first. In-dialect lowering will replace ops
+    // which need to be lowered further, which is not supported by a single
+    // conversion pass.
+    populateGpuRewritePatterns(m.getContext(), patterns);
+    applyPatternsGreedily(m, patterns);
+    patterns.clear();
+
     populateStdToLLVMConversionPatterns(converter, patterns);
     populateGpuToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
@@ -735,8 +302,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                           NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
               GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                           NVVM::GridDimYOp, NVVM::GridDimZOp>,
-              GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
-              GPUReturnOpLowering>(converter);
+              GPUShuffleOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
+          converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
                                                 "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 4ae9d8022bc3..f05c9af1b30f 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s --dump-input-on-failure
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
-      attributes { gpu.kernel } {
+      -> (index, index, index, index, index, index,
+          index, index, index, index, index, index) {
     // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
@@ -32,7 +33,10 @@ gpu.module @test_module {
     // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
 
-    std.return
+    std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
+               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+        : index, index, index, index, index, index,
+          index, index, index, index, index, index
   }
 }
 
@@ -40,8 +44,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_op()
-  func @gpu_all_reduce_op()
-      attributes { gpu.kernel } {
+  gpu.func @gpu_all_reduce_op() {
     %arg0 = constant 1.0 : f32
     // TODO(csigg): Check full IR expansion once lowering has settled.
     // CHECK: nvvm.shfl.sync.bfly
@@ -49,7 +52,7 @@ gpu.module @test_module {
     // CHECK: llvm.fadd
     %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
 
-    std.return
+    gpu.return
   }
 }
 
@@ -57,8 +60,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_region()
-  func @gpu_all_reduce_region()
-      attributes { gpu.kernel } {
+  gpu.func @gpu_all_reduce_region() {
     %arg0 = constant 1 : i32
     // TODO(csigg): Check full IR expansion once lowering has settled.
     // CHECK: nvvm.shfl.sync.bfly
@@ -68,7 +70,7 @@ gpu.module @test_module {
       %xor = xor %lhs, %rhs : i32
       "gpu.yield"(%xor) : (i32) -> ()
     }) : (i32) -> (i32)
-    std.return
+    gpu.return
   }
 }
 
@@ -76,8 +78,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_shuffle()
-  func @gpu_shuffle()
-      attributes { gpu.kernel } {
+  func @gpu_shuffle() -> (f32) {
     // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
     %arg0 = constant 1.0 : f32
     // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
@@ -93,7 +94,7 @@ gpu.module @test_module {
     // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm<"{ float, i1 }">
     %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
 
-    std.return
+    std.return %shfl : f32
   }
 }
 
@@ -101,8 +102,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_sync()
-  func @gpu_sync()
-      attributes { gpu.kernel } {
+  func @gpu_sync() {
     // CHECK: nvvm.barrier0
     gpu.barrier
     std.return
@@ -115,12 +115,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_fabsf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_fabs(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_fabs
-  func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.absf %arg_f32 : f32
     // CHECK: llvm.call @__nv_fabsf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.absf %arg_f64 : f64
     // CHECK: llvm.call @__nv_fabs(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -130,12 +130,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_ceilf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_ceil(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_ceil
-  func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.ceilf %arg_f32 : f32
     // CHECK: llvm.call @__nv_ceilf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.ceilf %arg_f64 : f64
     // CHECK: llvm.call @__nv_ceil(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -145,12 +145,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_cosf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_cos(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_cos
-  func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.cos %arg_f32 : f32
     // CHECK: llvm.call @__nv_cosf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.cos %arg_f64 : f64
     // CHECK: llvm.call @__nv_cos(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -159,14 +159,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_exp
-  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
-    %exp_f32 = std.exp %arg_f32 : f32
-    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
-    %result_f32 = std.exp %exp_f32 : f32
+  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = std.exp %arg_f32 : f32
     // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.exp %arg_f64 : f64
     // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -176,12 +174,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_logf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log
-  func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log %arg_f32 : f32
     // CHECK: llvm.call @__nv_logf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log %arg_f64 : f64
     // CHECK: llvm.call @__nv_log(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -191,12 +189,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_log10f(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log10(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log10
-  func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log10 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log10f(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log10 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log10(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -206,12 +204,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_log2f(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log2(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log2
-  func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log2 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log2f(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log2 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log2(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -221,12 +219,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -239,14 +237,12 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_exp
-    func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
-      %exp_f32 = std.exp %arg_f32 : f32
-      // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
-      %result_f32 = std.exp %exp_f32 : f32
+    func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+      %result32 = std.exp %arg_f32 : f32
       // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
       %result64 = std.exp %arg_f64 : f64
       // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
-      std.return
+      std.return %result32, %result64 : f32, f64
     }
     "test.finish" () : () -> ()
   }) : () -> ()


        


More information about the Mlir-commits mailing list