[Mlir-commits] [mlir] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (PR #67993)

Guray Ozen llvmlistbot at llvm.org
Mon Oct 2 07:47:13 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/67993

For the sake of better readability, this PR uses `ImplicitLocOpBuilder` instead of passing a rewriter and a separate `loc` to every builder call.

>From b53d0a84cd092287cfd7e038939d549d849f3539 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 2 Oct 2023 16:46:25 +0200
Subject: [PATCH] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass
 (NFC)

For the sake of better readability, this PR uses  `ImplicitLocOpBuilder` instead of rewriter+loc
---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 302 +++++++++---------
 1 file changed, 147 insertions(+), 155 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..d308a9e07a6080f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -19,6 +19,8 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
@@ -44,13 +46,12 @@ constexpr int exclude4LSB = 4;
 
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
-static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
-                        Value value) {
+static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
   Type type = value.getType();
   assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
   if (type.getIntOrFloatBitWidth() <= 32)
     return value;
-  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
 }
 
 /// Returns the type for the intrinsic given the vectorResultType of the
@@ -170,22 +171,23 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
 /// scalars of certain types. This function helps unpack the `vector` arguments
 /// and cast them to the types expected by `nvvm.mma.sync`.
-static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
-                                              Location loc, Value operand,
+static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
+                                              Value operand,
                                               NVVM::MMATypes operandPtxType) {
   SmallVector<Value> result;
-  Type i32Ty = rewriter.getI32Type();
-  Type f64Ty = rewriter.getF64Type();
-  Type f32Ty = rewriter.getF32Type();
-  Type i8Ty = rewriter.getI8Type();
-  Type i4Ty = rewriter.getIntegerType(4);
+  Type i32Ty = b.getI32Type();
+  Type f64Ty = b.getF64Type();
+  Type f32Ty = b.getF32Type();
+  Type i8Ty = b.getI8Type();
+  Type i64Ty = b.getI64Type();
+  Type i4Ty = b.getIntegerType(4);
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
   Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
-    Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
 
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
@@ -193,8 +195,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
         arrayTy.getElementType() == i4x8Ty ||
         (arrayTy.getElementType() == f32x1Ty &&
          operandPtxType == NVVM::MMATypes::tf32)) {
-      result.push_back(
-          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
       continue;
     }
 
@@ -207,10 +208,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                          innerArrayTy.getElementType() == f32Ty)) {
       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
            idx < innerSize; idx++) {
-        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
-            loc, toUse,
-            rewriter.create<LLVM::ConstantOp>(
-                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+        result.push_back(b.create<LLVM::ExtractElementOp>(
+            toUse,
+            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
       }
       continue;
     }
@@ -256,7 +256,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *ctx = getContext();
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // The result type of ldmatrix will always be a struct of 32bit integer
     // registers if more than one 32bit value is returned. Otherwise, the result
@@ -283,10 +283,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
 
     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
-        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                              adaptor.getIndices(), rewriter);
-    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
-        loc, ldMatrixResultType, srcPtr,
+    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
+        ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                      : NVVM::MMALayout::row);
@@ -296,15 +296,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     // actual vector type (still of width 32b) and repack them into a result
     // struct.
     Type finalResultType = typeConverter->convertType(vectorResultType);
-    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    Value result = b.create<LLVM::UndefOp>(finalResultType);
     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
       Value i32Register =
-          num32BitRegs > 1
-              ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
-              : ldMatrixResult;
-      Value casted =
-          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
-      result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
+          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+                           : ldMatrixResult;
+      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
+      result = b.create<LLVM::InsertValueOp>(result, casted, i);
     }
 
     rewriter.replaceOp(op, result);
@@ -335,7 +333,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
   LogicalResult
   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -368,17 +366,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
 
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
         typeConverter->convertType(op->getResultTypes()[0]));
-    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
-        op.getLoc(), intrinsicResTy, matA, matB, matC,
+    Value intrinsicResult = b.create<NVVM::MmaOp>(
+        intrinsicResTy, matA, matB, matC,
         /*shape=*/gemmShape,
         /*b1Op=*/std::nullopt,
         /*intOverflow=*/overflow,
@@ -511,14 +509,14 @@ static std::string buildMmaSparseAsmString(
 /// Builds an inline assembly operation corresponding to the specified MMA
 /// sparse sync operation.
 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
-    Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
     int64_t metadataSelector, const std::array<int64_t, 3> &shape,
-    Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
-  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                  LLVM::AsmDialect::AD_ATT);
+    Type intrinsicResultType) {
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
 
   const unsigned matASize = unpackedAData.size();
   const unsigned matBSize = unpackedB.size();
@@ -536,15 +534,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
     llvm::append_range(asmVals, args);
   asmVals.push_back(indexData);
 
-  return rewriter.create<LLVM::InlineAsmOp>(loc,
-                                            /*resultTypes=*/intrinsicResultType,
-                                            /*operands=*/asmVals,
-                                            /*asm_string=*/asmStr,
-                                            /*constraints=*/constraintStr,
-                                            /*has_side_effects=*/true,
-                                            /*is_align_stack=*/false,
-                                            /*asm_dialect=*/asmDialectAttr,
-                                            /*operand_attrs=*/ArrayAttr());
+  return b.create<LLVM::InlineAsmOp>(
+      /*resultTypes=*/intrinsicResultType,
+      /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/constraintStr,
+      /*has_side_effects=*/true,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
 }
 
 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -555,7 +553,7 @@ struct NVGPUMmaSparseSyncLowering
   LogicalResult
   matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -586,11 +584,11 @@ struct NVGPUMmaSparseSyncLowering
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
 
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
@@ -602,13 +600,13 @@ struct NVGPUMmaSparseSyncLowering
         LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
       return op->emitOpError() << "Expected metadata type to be LLVM "
                                   "VectorType of 2 i16 elements";
-    sparseMetadata = rewriter.create<LLVM::BitcastOp>(
-        loc, rewriter.getI32Type(), sparseMetadata);
+    sparseMetadata =
+        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
 
     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
-        loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
         matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
-        intrinsicResTy, rewriter);
+        intrinsicResTy);
     if (failed(intrinsicResult))
       return failure();
 
@@ -629,10 +627,12 @@ struct NVGPUAsyncCopyLowering
   LogicalResult
   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
-    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
-                                        adaptor.getDstIndices(), rewriter);
+    Value dstPtr =
+        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
+                             adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
     FailureOr<unsigned> dstAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@@ -642,7 +642,7 @@ struct NVGPUAsyncCopyLowering
     auto dstPointerType =
         getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+      dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);
 
     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
     FailureOr<unsigned> srcAddressSpace =
@@ -656,12 +656,11 @@ struct NVGPUAsyncCopyLowering
     auto srcPointerType =
         getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+      scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = getTypeConverter()->getPointerType(
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
-                                                    scrPtr);
+    scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -675,16 +674,14 @@ struct NVGPUAsyncCopyLowering
       // memory) of CpAsyncOp is read only for SrcElements number of elements.
       // The rest of the DstElements in the destination (shared memory) are
       // filled with zeros.
-      Value c3I32 = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
-      Value bitwidth = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(),
-          rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
-      Value srcElementsI32 =
-          rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
-      srcBytes = rewriter.create<LLVM::LShrOp>(
-          loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
-          c3I32);
+      Value c3I32 =
+          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
+      Value bitwidth = b.create<LLVM::ConstantOp>(
+          b.getI32Type(),
+          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
+      srcBytes = b.create<LLVM::LShrOp>(
+          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
     }
     // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
     // 16 dst bytes.
@@ -693,15 +690,14 @@ struct NVGPUAsyncCopyLowering
             ? NVVM::LoadCacheModifierKind::CG
             : NVVM::LoadCacheModifierKind::CA;
 
-    rewriter.create<NVVM::CpAsyncOp>(
-        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+    b.create<NVVM::CpAsyncOp>(
+        dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
         NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
         srcBytes);
 
     // Drop the result token.
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), IntegerType::get(op.getContext(), 32),
-        rewriter.getI32IntegerAttr(0));
+    Value zero = b.create<LLVM::ConstantOp>(
+        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
     rewriter.replaceOp(op, zero);
     return success();
   }
@@ -790,14 +786,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   /// Returns the base pointer of the mbarrier object.
-  Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
-                       Value memrefDesc, Value mbarId,
+  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
+                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
+                       Value mbarId,
                        ConversionPatternRewriter &rewriter) const {
     MemRefType mbarrierMemrefType =
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
     return ConvertToLLVMPattern::getStridedElementPtr(
-        op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
-    return memrefDesc;
+        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
   }
 };
 
@@ -809,11 +805,12 @@ struct NVGPUMBarrierInitLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
     rewriter.setInsertionPoint(op);
-    Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                    adaptor.getMbarId(), rewriter);
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(mbarrierType)) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
                                                               count);
@@ -831,8 +828,9 @@ struct NVGPUMBarrierArriveLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
@@ -856,12 +854,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
           op, tokenType, barrier, count);
@@ -880,8 +879,9 @@ struct NVGPUMBarrierTestWaitLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type retType = rewriter.getI1Type();
     if (isMbarrierShared(op.getBarriers().getType())) {
@@ -902,10 +902,11 @@ struct NVGPUMBarrierArriveExpectTxLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
+    Value txcount = truncToI32(b, adaptor.getTxcount());
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
@@ -926,11 +927,12 @@ struct NVGPUMBarrierTryWaitParityLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
-    Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
+    Value ticks = truncToI32(b, adaptor.getTicks());
+    Value phase = truncToI32(b, adaptor.getPhase());
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -950,16 +952,17 @@ struct NVGPUTmaAsyncLoadOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                       adaptor.getDst(), {}, rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
 
     SmallVector<Value> coords = adaptor.getCoordinates();
     for (auto [index, value] : llvm::enumerate(coords)) {
-      coords[index] = truncToI32(rewriter, op->getLoc(), value);
+      coords[index] = truncToI32(b, value);
     }
 
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
@@ -976,7 +979,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
   matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
 
     nvgpu::TensorMapSwizzleKind swizzleKind =
         op.getTensorMap().getType().getSwizzle();
@@ -992,20 +995,18 @@ struct NVGPUGenerateGmmaDescriptorLowering
         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                     : 0;
 
-    auto ti64 = rewriter.getIntegerType(64);
+    auto ti64 = b.getIntegerType(64);
     auto makeConst = [&](uint64_t index) -> Value {
-      return rewriter.create<LLVM::ConstantOp>(
-          loc, ti64, rewriter.getI64IntegerAttr(index));
+      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
     };
     auto shiftLeft = [&](Value value, unsigned shift) -> Value {
-      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
     };
     auto shiftRight = [&](Value value, unsigned shift) -> Value {
-      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
     };
     auto insertBit = [&](Value desc, Value val, int startBit) {
-      return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
-                                         shiftLeft(val, startBit));
+      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
     };
 
     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1019,7 +1020,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
     Value baseAddr = getStridedElementPtr(
         op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
         adaptor.getTensor(), {}, rewriter);
-    Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
     // Just use 14 bits for base address
     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
 
@@ -1050,16 +1051,13 @@ struct NVGPUGenerateGmmaDescriptorLowering
   }
 };
 
-static Value makeI64Const(RewriterBase &rewriter, Operation *op,
-                          int32_t index) {
-  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
-                                           rewriter.getIntegerType(64),
-                                           rewriter.getI32IntegerAttr(index));
+static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
+  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
+                                    b.getI32IntegerAttr(index));
 }
 
 /// Returns a Value that holds data type enum that is expected by CUDA driver.
-static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
-                                       Type type) {
+static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
   // Enum is from CUDA driver API
   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
   enum CUtensorMapDataTypeEnum {
@@ -1079,25 +1077,25 @@ static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
   };
 
   if (type.isUnsignedInteger(8))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT8);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
   if (type.isUnsignedInteger(16))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
   if (type.isUnsignedInteger(32))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
   if (type.isUnsignedInteger(64))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
   if (type.isSignlessInteger(32))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
   if (type.isSignlessInteger(64))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
   if (type.isF16())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
   if (type.isF32())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
   if (type.isF64())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
   if (type.isBF16())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
 
   llvm_unreachable("Not supported data type");
 }
@@ -1109,23 +1107,22 @@ struct NVGPUTmaCreateDescriptorOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     LLVM::LLVMPointerType llvmPointerType = getTypeConverter()->getPointerType(
         IntegerType::get(op->getContext(), 8));
     Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
 
-    Value tensorElementType = elementTypeAsLLVMConstant(
-        rewriter, op, op.getTensor().getType().getElementType());
+    Value tensorElementType =
+        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
     auto promotedOperands = getTypeConverter()->promoteOperands(
-        loc, op->getOperands(), adaptor.getOperands(), rewriter);
+        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
 
-    Value boxArrayPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, llvmPointerType, llvmInt64Type, makeI64Const(rewriter, op, 5));
+    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
+                                                 makeI64Const(b, 5));
     for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
-      Value gep = rewriter.create<LLVM::GEPOp>(
-          loc, llvmPointerType, llvmPointerType, boxArrayPtr,
-          makeI64Const(rewriter, op, index));
-      rewriter.create<LLVM::StoreOp>(loc, value, gep);
+      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
+                                        boxArrayPtr, makeI64Const(b, index));
+      b.create<LLVM::StoreOp>(value, gep);
     }
 
     nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1135,12 +1132,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
     arguments.push_back(promotedOperands[1]); // descriptor
     arguments.push_back(tensorElementType);   // data type
     arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getInterleave())); // interleave
-    arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getSwizzle())); // swizzle
-    arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getL2promo())); // l2promo
-    arguments.push_back(makeI64Const(rewriter, op, (int)desc.getOob())); // oob
+        makeI64Const(b, (int)desc.getInterleave()));              // interleave
+    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
+    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
+    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
     arguments.push_back(boxArrayPtr); // box dimensions
 
     // Set data types of the arguments
@@ -1157,7 +1152,7 @@ struct NVGPUTmaCreateDescriptorOpLowering
     FunctionCallBuilder hostRegisterCallBuilder = {
         "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
     Value tensorMap =
-        hostRegisterCallBuilder.create(loc, rewriter, arguments).getResult();
+        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
 
     rewriter.replaceOp(op, tensorMap);
     return success();
@@ -1191,11 +1186,10 @@ struct NVGPUWarpgroupMmaOpLowering
     return success();
   }
 
-  Value generateNVVMWgmmaOp(MLIRContext *ctx,
-                            ConversionPatternRewriter &rewriter, Location loc,
-                            int m, int n, int k, Type resultStructType,
-                            Value inout, Value descriptorA,
-                            Value descriptorB) const {
+  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
+                            Type resultStructType, Value inout,
+                            Value descriptorA, Value descriptorB) const {
+    MLIRContext *ctx = b.getContext();
     auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
     auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
     auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
@@ -1205,15 +1199,16 @@ struct NVGPUWarpgroupMmaOpLowering
     auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
     auto overflow =
         NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
-        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
+        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
+        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     return res;
   }
 
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
     int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
     int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
@@ -1232,13 +1227,11 @@ struct NVGPUWarpgroupMmaOpLowering
     Value descriptorB = adaptor.getDescriptorB();
 
     //  Generate wgmma group
-
-    auto loc = op->getLoc();
     MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
     MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
 
     auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
     auto iterateDescA = [&](Value desc, int iterM, int iterN,
@@ -1254,7 +1247,7 @@ struct NVGPUWarpgroupMmaOpLowering
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     };
 
     auto iterateDescB = [&](Value desc, int iterM, int iterN,
@@ -1266,10 +1259,10 @@ struct NVGPUWarpgroupMmaOpLowering
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     };
 
-    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+    b.create<NVVM::WgmmaFenceAlignedOp>();
 
     SmallVector<Value> wgmmaResults;
     for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
@@ -1291,14 +1284,13 @@ struct NVGPUWarpgroupMmaOpLowering
                           << " B[" << (iterK * wgmmaShapeK) << ":"
                           << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
                           << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
-                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
                                       structType, matrixC, descA, descB);
       }
       wgmmaResults.push_back(matrixC);
     }
-    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
     ValueRange myres(wgmmaResults);
     rewriter.replaceOp(op, myres);



More information about the Mlir-commits mailing list