[Mlir-commits] [mlir] [mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (PR #67993)
llvmlistbot at llvm.org
Mon Oct 2 07:48:17 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Changes:
For the sake of better readability, this PR uses `ImplicitLocOpBuilder` instead of passing a `rewriter` plus a `Location` into every `create` call.
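A minimal sketch of the before/after pattern, distilled from the `truncToI32` hunk in the diff (the standalone function below is illustrative only, not code from the patch):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Before: every create call repeats both the rewriter and the Location.
//   rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
//
// After: the Location is bound once into an ImplicitLocOpBuilder, and all
// subsequent create calls inherit it implicitly.
static Value truncToI32Sketch(ConversionPatternRewriter &rewriter,
                              Operation *op, Value value) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}
```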
---
Patch is 35.58 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/67993.diff
1 file affected:
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+146-155)
``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..c84960e0b22cc0f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
@@ -44,13 +45,12 @@ constexpr int exclude4LSB = 4;
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
-static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
- Value value) {
+static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
Type type = value.getType();
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
- return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+ return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@@ -170,22 +170,23 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
/// `nvvm.mma.sync` op expects these arguments to be given in a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
-static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
- Location loc, Value operand,
+static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
+ Value operand,
NVVM::MMATypes operandPtxType) {
SmallVector<Value> result;
- Type i32Ty = rewriter.getI32Type();
- Type f64Ty = rewriter.getF64Type();
- Type f32Ty = rewriter.getF32Type();
- Type i8Ty = rewriter.getI8Type();
- Type i4Ty = rewriter.getIntegerType(4);
+ Type i32Ty = b.getI32Type();
+ Type f64Ty = b.getF64Type();
+ Type f32Ty = b.getF32Type();
+ Type i8Ty = b.getI8Type();
+ Type i64Ty = b.getI64Type();
+ Type i4Ty = b.getIntegerType(4);
Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+ Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@@ -193,8 +194,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
- result.push_back(
- rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+ result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
continue;
}
@@ -207,10 +207,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
- result.push_back(rewriter.create<LLVM::ExtractElementOp>(
- loc, toUse,
- rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+ result.push_back(b.create<LLVM::ExtractElementOp>(
+ toUse,
+ b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@@ -256,7 +255,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = getContext();
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// The result type of ldmatrix will always be a struct of 32bit integer
// registers if more than one 32bit value is returned. Otherwise, the result
@@ -283,10 +282,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
- getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+ getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
- Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
- loc, ldMatrixResultType, srcPtr,
+ Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
+ ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@@ -296,15 +295,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+ Value result = b.create<LLVM::UndefOp>(finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
- num32BitRegs > 1
- ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
- : ldMatrixResult;
- Value casted =
- rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
+ num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+ : ldMatrixResult;
+ Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
+ result = b.create<LLVM::InsertValueOp>(result, casted, i);
}
rewriter.replaceOp(op, result);
@@ -335,7 +332,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
LogicalResult
matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@@ -368,17 +365,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+ unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+ unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+ unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
- op.getLoc(), intrinsicResTy, matA, matB, matC,
+ Value intrinsicResult = b.create<NVVM::MmaOp>(
+ intrinsicResTy, matA, matB, matC,
/*shape=*/gemmShape,
/*b1Op=*/std::nullopt,
/*intOverflow=*/overflow,
@@ -511,14 +508,14 @@ static std::string buildMmaSparseAsmString(
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
- Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+ ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
int64_t metadataSelector, const std::array<int64_t, 3> &shape,
- Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
- auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
- LLVM::AsmDialect::AD_ATT);
+ Type intrinsicResultType) {
+ auto asmDialectAttr =
+ LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
const unsigned matASize = unpackedAData.size();
const unsigned matBSize = unpackedB.size();
@@ -536,15 +533,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- return rewriter.create<LLVM::InlineAsmOp>(loc,
- /*resultTypes=*/intrinsicResultType,
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/constraintStr,
- /*has_side_effects=*/true,
- /*is_align_stack=*/false,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
+ return b.create<LLVM::InlineAsmOp>(
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -555,7 +552,7 @@ struct NVGPUMmaSparseSyncLowering
LogicalResult
matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@@ -586,11 +583,11 @@ struct NVGPUMmaSparseSyncLowering
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+ unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+ unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+ unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
@@ -602,13 +599,13 @@ struct NVGPUMmaSparseSyncLowering
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
- sparseMetadata = rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getI32Type(), sparseMetadata);
+ sparseMetadata =
+ b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
- loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+ b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
- intrinsicResTy, rewriter);
+ intrinsicResTy);
if (failed(intrinsicResult))
return failure();
@@ -629,10 +626,12 @@ struct NVGPUAsyncCopyLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Location loc = op.getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
- Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
- adaptor.getDstIndices(), rewriter);
+ Value dstPtr =
+ getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
+ adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@@ -642,7 +641,7 @@ struct NVGPUAsyncCopyLowering
auto dstPointerType =
getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
- dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+ dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
FailureOr<unsigned> srcAddressSpace =
@@ -656,12 +655,11 @@ struct NVGPUAsyncCopyLowering
auto srcPointerType =
getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
- scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+ scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = getTypeConverter()->getPointerType(
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
- scrPtr);
+ scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -675,16 +673,14 @@ struct NVGPUAsyncCopyLowering
// memory) of CpAsyncOp is read only for SrcElements number of elements.
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
- Value c3I32 = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
- Value bitwidth = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
- Value srcElementsI32 =
- rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
- srcBytes = rewriter.create<LLVM::LShrOp>(
- loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
- c3I32);
+ Value c3I32 =
+ b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
+ Value bitwidth = b.create<LLVM::ConstantOp>(
+ b.getI32Type(),
+ b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+ Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
+ srcBytes = b.create<LLVM::LShrOp>(
+ b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@@ -693,15 +689,14 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
- rewriter.create<NVVM::CpAsyncOp>(
- loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ b.create<NVVM::CpAsyncOp>(
+ dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), IntegerType::get(op.getContext(), 32),
- rewriter.getI32IntegerAttr(0));
+ Value zero = b.create<LLVM::ConstantOp>(
+ IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -790,14 +785,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
/// Returns the base pointer of the mbarrier object.
- Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
- Value memrefDesc, Value mbarId,
+ Value getMbarrierPtr(ImplicitLocOpBuilder &b,
+ nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
+ Value mbarId,
ConversionPatternRewriter &rewriter) const {
MemRefType mbarrierMemrefType =
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
return ConvertToLLVMPattern::getStridedElementPtr(
- op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
- return memrefDesc;
+ b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
}
};
@@ -809,11 +804,12 @@ struct NVGPUMBarrierInitLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
rewriter.setInsertionPoint(op);
- Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+ Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+ Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(mbarrierType)) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
count);
@@ -831,8 +827,9 @@ struct NVGPUMBarrierArriveLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
@@ -856,12 +853,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
- getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
- Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+ Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
op, tokenType, barrier, count);
@@ -880,8 +878,9 @@ struct NVGPUM...
[truncated]
``````````
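Two interop details recur throughout the diff: helpers such as `getStridedElementPtr` still take an explicit `Location`, which is recovered via `b.getLoc()`, and replacing the matched op still goes through the `ConversionPatternRewriter`. A hedged fragment illustrating that split (assumes it sits inside a `ConvertOpToLLVMPattern::matchAndRewrite`; `memrefType`, `memrefDesc`, and `indices` are placeholder operands):

```cpp
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Ops built through `b` pick up its bound location implicitly.
Value zero =
    b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(0));

// APIs that still require an explicit Location get it back via b.getLoc().
Value ptr = getStridedElementPtr(b.getLoc(), memrefType, memrefDesc,
                                 indices, rewriter);

// Erasing or replacing the matched op remains the rewriter's job.
rewriter.replaceOp(op, zero);
```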
https://github.com/llvm/llvm-project/pull/67993