[Mlir-commits] [mlir] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (PR #67993)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 2 07:48:17 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

For better readability, this PR uses `ImplicitLocOpBuilder` instead of passing a rewriter and an explicit `loc` to every `create` call.

---

Patch is 35.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67993.diff


1 Files Affected:

- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+146-155) 


``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..c84960e0b22cc0f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
@@ -44,13 +45,12 @@ constexpr int exclude4LSB = 4;
 
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
-static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
-                        Value value) {
+static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
   Type type = value.getType();
   assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
   if (type.getIntOrFloatBitWidth() <= 32)
     return value;
-  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
 }
 
 /// Returns the type for the intrinsic given the vectorResultType of the
@@ -170,22 +170,23 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
 /// scalars of certain types. This function helps unpack the `vector` arguments
 /// and cast them to the types expected by `nvvm.mma.sync`.
-static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
-                                              Location loc, Value operand,
+static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
+                                              Value operand,
                                               NVVM::MMATypes operandPtxType) {
   SmallVector<Value> result;
-  Type i32Ty = rewriter.getI32Type();
-  Type f64Ty = rewriter.getF64Type();
-  Type f32Ty = rewriter.getF32Type();
-  Type i8Ty = rewriter.getI8Type();
-  Type i4Ty = rewriter.getIntegerType(4);
+  Type i32Ty = b.getI32Type();
+  Type f64Ty = b.getF64Type();
+  Type f32Ty = b.getF32Type();
+  Type i8Ty = b.getI8Type();
+  Type i64Ty = b.getI64Type();
+  Type i4Ty = b.getIntegerType(4);
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
   Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
-    Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
 
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
@@ -193,8 +194,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
         arrayTy.getElementType() == i4x8Ty ||
         (arrayTy.getElementType() == f32x1Ty &&
          operandPtxType == NVVM::MMATypes::tf32)) {
-      result.push_back(
-          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
       continue;
     }
 
@@ -207,10 +207,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                          innerArrayTy.getElementType() == f32Ty)) {
       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
            idx < innerSize; idx++) {
-        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
-            loc, toUse,
-            rewriter.create<LLVM::ConstantOp>(
-                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+        result.push_back(b.create<LLVM::ExtractElementOp>(
+            toUse,
+            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
       }
       continue;
     }
@@ -256,7 +255,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *ctx = getContext();
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // The result type of ldmatrix will always be a struct of 32bit integer
     // registers if more than one 32bit value is returned. Otherwise, the result
@@ -283,10 +282,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
 
     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
-        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                              adaptor.getIndices(), rewriter);
-    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
-        loc, ldMatrixResultType, srcPtr,
+    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
+        ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                      : NVVM::MMALayout::row);
@@ -296,15 +295,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     // actual vector type (still of width 32b) and repack them into a result
     // struct.
     Type finalResultType = typeConverter->convertType(vectorResultType);
-    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    Value result = b.create<LLVM::UndefOp>(finalResultType);
     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
       Value i32Register =
-          num32BitRegs > 1
-              ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
-              : ldMatrixResult;
-      Value casted =
-          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
-      result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
+          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+                           : ldMatrixResult;
+      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
+      result = b.create<LLVM::InsertValueOp>(result, casted, i);
     }
 
     rewriter.replaceOp(op, result);
@@ -335,7 +332,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
   LogicalResult
   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -368,17 +365,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
 
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
         typeConverter->convertType(op->getResultTypes()[0]));
-    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
-        op.getLoc(), intrinsicResTy, matA, matB, matC,
+    Value intrinsicResult = b.create<NVVM::MmaOp>(
+        intrinsicResTy, matA, matB, matC,
         /*shape=*/gemmShape,
         /*b1Op=*/std::nullopt,
         /*intOverflow=*/overflow,
@@ -511,14 +508,14 @@ static std::string buildMmaSparseAsmString(
 /// Builds an inline assembly operation corresponding to the specified MMA
 /// sparse sync operation.
 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
-    Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
     int64_t metadataSelector, const std::array<int64_t, 3> &shape,
-    Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
-  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                  LLVM::AsmDialect::AD_ATT);
+    Type intrinsicResultType) {
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
 
   const unsigned matASize = unpackedAData.size();
   const unsigned matBSize = unpackedB.size();
@@ -536,15 +533,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
     llvm::append_range(asmVals, args);
   asmVals.push_back(indexData);
 
-  return rewriter.create<LLVM::InlineAsmOp>(loc,
-                                            /*resultTypes=*/intrinsicResultType,
-                                            /*operands=*/asmVals,
-                                            /*asm_string=*/asmStr,
-                                            /*constraints=*/constraintStr,
-                                            /*has_side_effects=*/true,
-                                            /*is_align_stack=*/false,
-                                            /*asm_dialect=*/asmDialectAttr,
-                                            /*operand_attrs=*/ArrayAttr());
+  return b.create<LLVM::InlineAsmOp>(
+      /*resultTypes=*/intrinsicResultType,
+      /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/constraintStr,
+      /*has_side_effects=*/true,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
 }
 
 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -555,7 +552,7 @@ struct NVGPUMmaSparseSyncLowering
   LogicalResult
   matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -586,11 +583,11 @@ struct NVGPUMmaSparseSyncLowering
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
 
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
@@ -602,13 +599,13 @@ struct NVGPUMmaSparseSyncLowering
         LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
       return op->emitOpError() << "Expected metadata type to be LLVM "
                                   "VectorType of 2 i16 elements";
-    sparseMetadata = rewriter.create<LLVM::BitcastOp>(
-        loc, rewriter.getI32Type(), sparseMetadata);
+    sparseMetadata =
+        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
 
     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
-        loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
         matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
-        intrinsicResTy, rewriter);
+        intrinsicResTy);
     if (failed(intrinsicResult))
       return failure();
 
@@ -629,10 +626,12 @@ struct NVGPUAsyncCopyLowering
   LogicalResult
   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
-    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
-                                        adaptor.getDstIndices(), rewriter);
+    Value dstPtr =
+        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
+                             adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
     FailureOr<unsigned> dstAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@@ -642,7 +641,7 @@ struct NVGPUAsyncCopyLowering
     auto dstPointerType =
         getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+      dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);
 
     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
     FailureOr<unsigned> srcAddressSpace =
@@ -656,12 +655,11 @@ struct NVGPUAsyncCopyLowering
     auto srcPointerType =
         getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+      scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = getTypeConverter()->getPointerType(
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
-                                                    scrPtr);
+    scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -675,16 +673,14 @@ struct NVGPUAsyncCopyLowering
       // memory) of CpAsyncOp is read only for SrcElements number of elements.
       // The rest of the DstElements in the destination (shared memory) are
       // filled with zeros.
-      Value c3I32 = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
-      Value bitwidth = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(),
-          rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
-      Value srcElementsI32 =
-          rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
-      srcBytes = rewriter.create<LLVM::LShrOp>(
-          loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
-          c3I32);
+      Value c3I32 =
+          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
+      Value bitwidth = b.create<LLVM::ConstantOp>(
+          b.getI32Type(),
+          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
+      srcBytes = b.create<LLVM::LShrOp>(
+          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
     }
     // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
     // 16 dst bytes.
@@ -693,15 +689,14 @@ struct NVGPUAsyncCopyLowering
             ? NVVM::LoadCacheModifierKind::CG
             : NVVM::LoadCacheModifierKind::CA;
 
-    rewriter.create<NVVM::CpAsyncOp>(
-        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+    b.create<NVVM::CpAsyncOp>(
+        dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
         NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
         srcBytes);
 
     // Drop the result token.
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), IntegerType::get(op.getContext(), 32),
-        rewriter.getI32IntegerAttr(0));
+    Value zero = b.create<LLVM::ConstantOp>(
+        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
     rewriter.replaceOp(op, zero);
     return success();
   }
@@ -790,14 +785,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   /// Returns the base pointer of the mbarrier object.
-  Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
-                       Value memrefDesc, Value mbarId,
+  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
+                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
+                       Value mbarId,
                        ConversionPatternRewriter &rewriter) const {
     MemRefType mbarrierMemrefType =
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
     return ConvertToLLVMPattern::getStridedElementPtr(
-        op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
-    return memrefDesc;
+        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
   }
 };
 
@@ -809,11 +804,12 @@ struct NVGPUMBarrierInitLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
     rewriter.setInsertionPoint(op);
-    Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                    adaptor.getMbarId(), rewriter);
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(mbarrierType)) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
                                                               count);
@@ -831,8 +827,9 @@ struct NVGPUMBarrierArriveLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
@@ -856,12 +853,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
           op, tokenType, barrier, count);
@@ -880,8 +878,9 @@ struct NVGPUM...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67993


More information about the Mlir-commits mailing list