[Mlir-commits] [mlir] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (PR #67993)
Guray Ozen
llvmlistbot at llvm.org
Mon Oct 2 07:47:13 PDT 2023
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/67993
To improve readability, this PR uses `ImplicitLocOpBuilder` instead of passing a rewriter and a separate `Location` (rewriter+loc) to each op-creation call.
>From b53d0a84cd092287cfd7e038939d549d849f3539 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 2 Oct 2023 16:46:25 +0200
Subject: [PATCH] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass
(NFC)
To improve readability, this PR uses `ImplicitLocOpBuilder` instead of passing a rewriter and a separate `Location` (rewriter+loc) to each op-creation call.
---
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 302 +++++++++---------
1 file changed, 147 insertions(+), 155 deletions(-)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..d308a9e07a6080f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -19,6 +19,8 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
@@ -44,13 +46,12 @@ constexpr int exclude4LSB = 4;
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
-static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
- Value value) {
+static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
Type type = value.getType();
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
- return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+ return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@@ -170,22 +171,23 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
-static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
- Location loc, Value operand,
+static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
+ Value operand,
NVVM::MMATypes operandPtxType) {
SmallVector<Value> result;
- Type i32Ty = rewriter.getI32Type();
- Type f64Ty = rewriter.getF64Type();
- Type f32Ty = rewriter.getF32Type();
- Type i8Ty = rewriter.getI8Type();
- Type i4Ty = rewriter.getIntegerType(4);
+ Type i32Ty = b.getI32Type();
+ Type f64Ty = b.getF64Type();
+ Type f32Ty = b.getF32Type();
+ Type i8Ty = b.getI8Type();
+ Type i64Ty = b.getI64Type();
+ Type i4Ty = b.getIntegerType(4);
Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+ Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@@ -193,8 +195,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
- result.push_back(
- rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+ result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
continue;
}
@@ -207,10 +208,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
- result.push_back(rewriter.create<LLVM::ExtractElementOp>(
- loc, toUse,
- rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+ result.push_back(b.create<LLVM::ExtractElementOp>(
+ toUse,
+ b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@@ -256,7 +256,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = getContext();
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// The result type of ldmatrix will always be a struct of 32bit integer
// registers if more than one 32bit value is returned. Otherwise, the result
@@ -283,10 +283,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
- getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+ getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
- Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
- loc, ldMatrixResultType, srcPtr,
+ Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
+ ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@@ -296,15 +296,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+ Value result = b.create<LLVM::UndefOp>(finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
- num32BitRegs > 1
- ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
- : ldMatrixResult;
- Value casted =
- rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
+ num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+ : ldMatrixResult;
+ Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
+ result = b.create<LLVM::InsertValueOp>(result, casted, i);
}
rewriter.replaceOp(op, result);
@@ -335,7 +333,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
LogicalResult
matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@@ -368,17 +366,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+ unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+ unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+ unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
- op.getLoc(), intrinsicResTy, matA, matB, matC,
+ Value intrinsicResult = b.create<NVVM::MmaOp>(
+ intrinsicResTy, matA, matB, matC,
/*shape=*/gemmShape,
/*b1Op=*/std::nullopt,
/*intOverflow=*/overflow,
@@ -511,14 +509,14 @@ static std::string buildMmaSparseAsmString(
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
- Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+ ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
int64_t metadataSelector, const std::array<int64_t, 3> &shape,
- Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
- auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
- LLVM::AsmDialect::AD_ATT);
+ Type intrinsicResultType) {
+ auto asmDialectAttr =
+ LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
const unsigned matASize = unpackedAData.size();
const unsigned matBSize = unpackedB.size();
@@ -536,15 +534,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- return rewriter.create<LLVM::InlineAsmOp>(loc,
- /*resultTypes=*/intrinsicResultType,
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/constraintStr,
- /*has_side_effects=*/true,
- /*is_align_stack=*/false,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
+ return b.create<LLVM::InlineAsmOp>(
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -555,7 +553,7 @@ struct NVGPUMmaSparseSyncLowering
LogicalResult
matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@@ -586,11 +584,11 @@ struct NVGPUMmaSparseSyncLowering
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+ unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+ unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+ unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
@@ -602,13 +600,13 @@ struct NVGPUMmaSparseSyncLowering
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
- sparseMetadata = rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getI32Type(), sparseMetadata);
+ sparseMetadata =
+ b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
- loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+ b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
- intrinsicResTy, rewriter);
+ intrinsicResTy);
if (failed(intrinsicResult))
return failure();
@@ -629,10 +627,12 @@ struct NVGPUAsyncCopyLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Location loc = op.getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
- Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
- adaptor.getDstIndices(), rewriter);
+ Value dstPtr =
+ getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
+ adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@@ -642,7 +642,7 @@ struct NVGPUAsyncCopyLowering
auto dstPointerType =
getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
- dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+ dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
FailureOr<unsigned> srcAddressSpace =
@@ -656,12 +656,11 @@ struct NVGPUAsyncCopyLowering
auto srcPointerType =
getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
- scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+ scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = getTypeConverter()->getPointerType(
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
- scrPtr);
+ scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -675,16 +674,14 @@ struct NVGPUAsyncCopyLowering
// memory) of CpAsyncOp is read only for SrcElements number of elements.
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
- Value c3I32 = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
- Value bitwidth = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
- Value srcElementsI32 =
- rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
- srcBytes = rewriter.create<LLVM::LShrOp>(
- loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
- c3I32);
+ Value c3I32 =
+ b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
+ Value bitwidth = b.create<LLVM::ConstantOp>(
+ b.getI32Type(),
+ b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+ Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
+ srcBytes = b.create<LLVM::LShrOp>(
+ b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@@ -693,15 +690,14 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
- rewriter.create<NVVM::CpAsyncOp>(
- loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ b.create<NVVM::CpAsyncOp>(
+ dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), IntegerType::get(op.getContext(), 32),
- rewriter.getI32IntegerAttr(0));
+ Value zero = b.create<LLVM::ConstantOp>(
+ IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -790,14 +786,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
/// Returns the base pointer of the mbarrier object.
- Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
- Value memrefDesc, Value mbarId,
+ Value getMbarrierPtr(ImplicitLocOpBuilder &b,
+ nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
+ Value mbarId,
ConversionPatternRewriter &rewriter) const {
MemRefType mbarrierMemrefType =
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
return ConvertToLLVMPattern::getStridedElementPtr(
- op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
- return memrefDesc;
+ b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
}
};
@@ -809,11 +805,12 @@ struct NVGPUMBarrierInitLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
rewriter.setInsertionPoint(op);
- Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+ Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+ Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(mbarrierType)) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
count);
@@ -831,8 +828,9 @@ struct NVGPUMBarrierArriveLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
@@ -856,12 +854,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
- Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+ Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
op, tokenType, barrier, count);
@@ -880,8 +879,9 @@ struct NVGPUMBarrierTestWaitLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type retType = rewriter.getI1Type();
if (isMbarrierShared(op.getBarriers().getType())) {
@@ -902,10 +902,11 @@ struct NVGPUMBarrierArriveExpectTxLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
+ Value txcount = truncToI32(b, adaptor.getTxcount());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
@@ -926,11 +927,12 @@ struct NVGPUMBarrierTryWaitParityLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
- Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
+ Value ticks = truncToI32(b, adaptor.getTicks());
+ Value phase = truncToI32(b, adaptor.getPhase());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -950,16 +952,17 @@ struct NVGPUTmaAsyncLoadOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
adaptor.getDst(), {}, rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) {
- coords[index] = truncToI32(rewriter, op->getLoc(), value);
+ coords[index] = truncToI32(b, value);
}
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
@@ -976,7 +979,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
nvgpu::TensorMapSwizzleKind swizzleKind =
op.getTensorMap().getType().getSwizzle();
@@ -992,20 +995,18 @@ struct NVGPUGenerateGmmaDescriptorLowering
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
: 0;
- auto ti64 = rewriter.getIntegerType(64);
+ auto ti64 = b.getIntegerType(64);
auto makeConst = [&](uint64_t index) -> Value {
- return rewriter.create<LLVM::ConstantOp>(
- loc, ti64, rewriter.getI64IntegerAttr(index));
+ return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
};
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
- return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+ return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
};
auto shiftRight = [&](Value value, unsigned shift) -> Value {
- return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+ return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
};
auto insertBit = [&](Value desc, Value val, int startBit) {
- return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
- shiftLeft(val, startBit));
+ return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
};
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1019,7 +1020,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
Value baseAddr = getStridedElementPtr(
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {}, rewriter);
- Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+ Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
@@ -1050,16 +1051,13 @@ struct NVGPUGenerateGmmaDescriptorLowering
}
};
-static Value makeI64Const(RewriterBase &rewriter, Operation *op,
- int32_t index) {
- return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
- rewriter.getIntegerType(64),
- rewriter.getI32IntegerAttr(index));
+static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
+ return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
+ b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
-static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
- Type type) {
+static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
// Enum is from CUDA driver API
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
enum CUtensorMapDataTypeEnum {
@@ -1079,25 +1077,25 @@ static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
};
if (type.isUnsignedInteger(8))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT8);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
if (type.isUnsignedInteger(16))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT16);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
if (type.isUnsignedInteger(32))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT32);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
if (type.isUnsignedInteger(64))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT64);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
if (type.isSignlessInteger(32))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT32);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
if (type.isSignlessInteger(64))
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT64);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
if (type.isF16())
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
if (type.isF32())
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
if (type.isF64())
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
if (type.isBF16())
- return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
+ return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
llvm_unreachable("Not supported data type");
}
@@ -1109,23 +1107,22 @@ struct NVGPUTmaCreateDescriptorOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
LLVM::LLVMPointerType llvmPointerType = getTypeConverter()->getPointerType(
IntegerType::get(op->getContext(), 8));
Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
- Value tensorElementType = elementTypeAsLLVMConstant(
- rewriter, op, op.getTensor().getType().getElementType());
+ Value tensorElementType =
+ elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
auto promotedOperands = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
- Value boxArrayPtr = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt64Type, makeI64Const(rewriter, op, 5));
+ Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
+ makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
- Value gep = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, boxArrayPtr,
- makeI64Const(rewriter, op, index));
- rewriter.create<LLVM::StoreOp>(loc, value, gep);
+ Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
+ boxArrayPtr, makeI64Const(b, index));
+ b.create<LLVM::StoreOp>(value, gep);
}
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1135,12 +1132,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
arguments.push_back(promotedOperands[1]); // descriptor
arguments.push_back(tensorElementType); // data type
arguments.push_back(
- makeI64Const(rewriter, op, (int)desc.getInterleave())); // interleave
- arguments.push_back(
- makeI64Const(rewriter, op, (int)desc.getSwizzle())); // swizzle
- arguments.push_back(
- makeI64Const(rewriter, op, (int)desc.getL2promo())); // l2promo
- arguments.push_back(makeI64Const(rewriter, op, (int)desc.getOob())); // oob
+ makeI64Const(b, (int)desc.getInterleave())); // interleave
+ arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
+ arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
+ arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
arguments.push_back(boxArrayPtr); // box dimensions
// Set data types of the arguments
@@ -1157,7 +1152,7 @@ struct NVGPUTmaCreateDescriptorOpLowering
FunctionCallBuilder hostRegisterCallBuilder = {
"mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
Value tensorMap =
- hostRegisterCallBuilder.create(loc, rewriter, arguments).getResult();
+ hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
rewriter.replaceOp(op, tensorMap);
return success();
@@ -1191,11 +1186,10 @@ struct NVGPUWarpgroupMmaOpLowering
return success();
}
- Value generateNVVMWgmmaOp(MLIRContext *ctx,
- ConversionPatternRewriter &rewriter, Location loc,
- int m, int n, int k, Type resultStructType,
- Value inout, Value descriptorA,
- Value descriptorB) const {
+ Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
+ Type resultStructType, Value inout,
+ Value descriptorA, Value descriptorB) const {
+ MLIRContext *ctx = b.getContext();
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
@@ -1205,15 +1199,16 @@ struct NVGPUWarpgroupMmaOpLowering
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
- Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
- loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
- itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+ Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
+ resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
+ scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
@@ -1232,13 +1227,11 @@ struct NVGPUWarpgroupMmaOpLowering
Value descriptorB = adaptor.getDescriptorB();
// Generate wgmma group
-
- auto loc = op->getLoc();
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
- return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+ return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};
auto iterateDescA = [&](Value desc, int iterM, int iterN,
@@ -1254,7 +1247,7 @@ struct NVGPUWarpgroupMmaOpLowering
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
- return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+ return makeAdd(desc, makeI64Const(b, incrementVal));
};
auto iterateDescB = [&](Value desc, int iterM, int iterN,
@@ -1266,10 +1259,10 @@ struct NVGPUWarpgroupMmaOpLowering
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
- return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+ return makeAdd(desc, makeI64Const(b, incrementVal));
};
- rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+ b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
@@ -1291,14 +1284,13 @@ struct NVGPUWarpgroupMmaOpLowering
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
- matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
- wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+ matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
}
wgmmaResults.push_back(matrixC);
}
- rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
- rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+ b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+ b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);
More information about the Mlir-commits
mailing list