[Mlir-commits] [mlir] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687) (PR #149879)
Maksim Levental
llvmlistbot at llvm.org
Mon Jul 21 13:01:23 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149879
>From 24edc8dd17e99d4e6fdffab87261db80e1ef8c77 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 21 Jul 2025 15:26:00 -0400
Subject: [PATCH] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 278 +++++++++---------
.../AffineToStandard/AffineToStandard.cpp | 44 +--
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 213 +++++++-------
.../ArithToArmSME/ArithToArmSME.cpp | 8 +-
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 90 +++---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 38 +--
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 103 +++----
.../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 8 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 213 +++++++-------
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 112 +++----
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 106 +++----
.../BufferizationToMemRef.cpp | 27 +-
12 files changed, 631 insertions(+), 609 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..fe3dc91328879 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
- ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
- : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+ ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+ : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type i32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
bool value) {
Type llvmI1 = rewriter.getI1Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+ return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}
/// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
- : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+ : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+ increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+ : increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
- ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
- return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+ return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
- rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
- Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(1 << 14));
- stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
- /*isDisjoint=*/true);
+ LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+ Value swizzleBit = LLVM::ConstantOp::create(
+ rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
} else {
- stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
- rewriter.getI16IntegerAttr(0));
+ stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+ rewriter.getI16IntegerAttr(0));
}
// Get the number of elements.
// Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(0))
+ ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor)
- : Value{};
- Value strides = hasSizes
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor)
- : Value{};
+ Value sizes = hasSizes
+ ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides =
+ hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kStridePosInMemRefDescriptor)
+ : Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
- kOffsetPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAllocatedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAlignedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
if (hasSizes) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, strides, kStridePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+ kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForStore =
- rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+ Value castForStore = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForCmp = rewriter.create<LLVM::BitcastOp>(
- loc, llvmBufferValType, atomicCmpData);
+ Value castForCmp = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
- voffset =
- voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
- : extraOffsetConst;
+ voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+ extraOffsetConst)
+ : extraOffsetConst;
}
- voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+ voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+ sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
- Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
- replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
- replacement);
+ replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+ replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
@@ -465,12 +466,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;
Location loc = op->getLoc();
- rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+ ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
- rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
- rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+ ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+ ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
}
@@ -516,19 +517,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
- return rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
- return rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
- return rewriter.create<LLVM::BitcastOp>(
- loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
}
}
return input;
@@ -549,8 +552,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
- return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
- return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -576,8 +579,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- llvmInput = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+ llvmInput = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -613,7 +616,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
- castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+ castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
@@ -633,8 +636,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- output = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), output);
+ output = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -992,7 +995,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
- lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
@@ -1092,8 +1095,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
- maybeCastBack =
- rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
@@ -1143,22 +1146,22 @@ struct TransposeLoadOpLowering
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
@@ -1316,21 +1319,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
- Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1382,21 +1385,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
- Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1454,54 +1457,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
else
- existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
+ existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
- Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+ Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
- source = rewriter.create<LLVM::ZeroOp>(loc, v2);
- source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
+ source = LLVM::ZeroOp::create(rewriter, loc, v2);
+ source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
}
Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
- sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
- sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+ sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
+ sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
}
Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
@@ -1526,20 +1532,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
- sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1563,17 +1569,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value stoch = adaptor.getStochiasticParam();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1617,14 +1623,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
- rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
- Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
- operand = rewriter.create<LLVM::InsertElementOp>(
- loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
- operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+ Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
+ operand =
+ LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
+ createI32Constant(rewriter, loc, 0));
+ operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
}
return operand;
};
@@ -1711,14 +1718,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
- auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
- loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+ auto dppMovOp =
+ ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
+ rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
- result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
- result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
}
}
@@ -1752,7 +1760,7 @@ struct AMDGPUSwizzleBitModeLowering
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
- rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 3b143ca1ef9ce..3b148f9021666 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
if (predicate == arith::CmpIPredicate::sgt)
- value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+ value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
else
- value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+ value = arith::MinSIOp::create(builder, loc, value, *valueIt);
}
return value;
@@ -154,9 +154,9 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
Value step =
- rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
- auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
- step, op.getInits());
+ arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
+ auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
+ step, op.getInits());
rewriter.eraseBlock(scfForOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
scfForOp.getRegion().end());
@@ -197,7 +197,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
}
steps.reserve(op.getSteps().size());
for (int64_t step : op.getSteps())
- steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
+ steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
// Get the terminator op.
auto affineParOpTerminator =
@@ -205,9 +205,9 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
- parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
- upperBoundTuple, steps,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps,
+ /*bodyBuilderFn=*/nullptr);
rewriter.eraseBlock(parOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
parOp.getRegion().end());
@@ -233,9 +233,9 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
identityVals.push_back(
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
- parOp = rewriter.create<scf::ParallelOp>(
- loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps, identityVals,
+ /*bodyBuilderFn=*/nullptr);
// Copy the body of the affine.parallel op.
rewriter.eraseBlock(parOp.getBody());
@@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
reductionBody.getArgument(1));
- rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
+ scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
return success();
@@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
// Now we just have to handle the condition logic.
auto integerSet = op.getIntegerSet();
- Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::ArrayRef(operands);
@@ -298,18 +298,18 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
auto pred =
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
Value cmpVal =
- rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
- cond = cond
- ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
- : cmpVal;
+ arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
+ cond =
+ cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
+ : cmpVal;
}
cond = cond ? cond
- : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
- /*width=*/1);
+ : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
+ /*width=*/1);
bool hasElseRegion = !op.getElseRegion().empty();
- auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
- hasElseRegion);
+ auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
+ hasElseRegion);
rewriter.inlineRegionBefore(op.getThenRegion(),
&ifOp.getThenRegion().back());
rewriter.eraseBlock(&ifOp.getThenRegion().back());
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679c5039e..73a17b09721b2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -115,9 +115,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+ return arith::TruncFOp::create(rewriter, loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+ return arith::ExtFOp::create(rewriter, loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -139,27 +139,27 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = inVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
Value scalarExt =
- rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
- ArrayRef<int64_t>{});
+ arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+ zerodSplat, ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -171,32 +171,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
inVecType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
- Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, i, elemsThisOp, 1);
+ Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+ elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
if (i + j + 1 < numElements) { // Convert two 8-bit elements
- Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, extResType, inSlice, j / 2);
+ Value asFloats = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, extResType, inSlice, j / 2);
Type desType = VectorType::get(2, outElemType);
Value asType = castF32To(desType, asFloats, loc, rewriter);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, asType, result, i + j, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+ result, i + j, 1);
} else { // Convert a 8-bit element
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
}
}
}
if (inVecType.getRank() != outType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
}
rewriter.replaceOp(op, result);
@@ -208,9 +208,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+ return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+ return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -250,13 +250,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
- Value isNonFinite = rewriter.create<arith::OrIOp>(
- loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+ Value isNonFinite = arith::OrIOp::create(
+ rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+ isNan);
- Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
- Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+ Value clamped =
+ arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
Value res =
- rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
return res;
}
@@ -290,25 +292,25 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType truncResType = VectorType::get(4, outElemType);
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
- Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = outVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
if (outVecType.getShape().empty()) {
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
- rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -320,32 +322,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
- Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+ Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
- thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -373,10 +375,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
// Handle the case where input type is not a vector type
if (!inVectorTy) {
- auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
Value asF16s =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
rewriter.replaceOp(op, result);
return success();
}
@@ -389,7 +391,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
// Handle the vector case. We also handle the (uncommon) case where the vector
@@ -397,25 +399,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
- Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+ Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+ elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
}
thisResult =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
// Place back the truncated result into the possibly larger vector. If we
// are operating on a size 2 vector, these operations should be folded away
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -472,18 +474,18 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast = rewriter.create<vector::BroadcastOp>(
- loc, VectorType::get(1, inType), in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc,
+ VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, inCast, scale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inCast, scale, 0);
scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
return success();
}
@@ -508,20 +510,20 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
int64_t blockSize = computeProduct(ratio);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
@@ -530,23 +532,23 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, slice, uniformScale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, slice, uniformScale, 0);
if (sliceWidth != opWidth)
- scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleExt, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleExt, blockResult, i, 1);
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, sliceWidth, 1);
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -578,21 +580,22 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, inCast, scale, 0,
+ /*existing=*/nullptr);
scaleTrunc =
rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
return success();
@@ -623,13 +626,13 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
@@ -638,26 +641,26 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, slice, uniformScale, 0,
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
/*existing=*/nullptr);
int64_t packedWidth =
cast<VectorType>(scaleTrunc.getType()).getNumElements();
if (packedWidth != opWidth)
- scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleTrunc, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleTrunc, blockResult, i, 1);
+ scaleTrunc = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
index cbe0b3fda3410..ba489436a1a4d 100644
--- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+ auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to write vector to tile
// slice.
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
auto forOp = mlir::arm_sme::createLoopOverTileSlices(
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index a5c08a6378021..59b3fe2e4eaed 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -110,9 +110,9 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
emitc::CmpPredicate predicate;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/false));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/false));
rewriter.replaceOp(op, constant);
return success();
}
@@ -179,9 +179,9 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
return success();
}
case arith::CmpFPredicate::AlwaysTrue: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/true));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/true));
rewriter.replaceOp(op, constant);
return success();
}
@@ -189,8 +189,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
// Compare the values naively
auto cmpResult =
- rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
- adaptor.getLhs(), adaptor.getRhs());
+ emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
+ adaptor.getLhs(), adaptor.getRhs());
// Adjust the results for unordered/ordered semantics
if (unordered) {
@@ -213,16 +213,16 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is NaN exactly when it compares unequal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::ne, operand, operand);
}
/// Return a value that is true if \p operand is not NaN.
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is not NaN exactly when it compares equal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::eq, operand, operand);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -231,8 +231,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Location loc, Value first, Value second) const {
auto firstIsNaN = isNaN(rewriter, loc, first);
auto secondIsNaN = isNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
- firstIsNaN, secondIsNaN);
+ return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNaN, secondIsNaN);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -241,8 +241,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Value first, Value second) const {
auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
- firstIsNotNaN, secondIsNotNaN);
+ return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNotNaN, secondIsNotNaN);
}
};
@@ -378,10 +378,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
Type attrType = (emitc::isPointerWideType(operandType))
? rewriter.getIndexType()
: operandType;
- auto constOne = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), operandType, rewriter.getOneAttr(attrType));
- auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
- op.getLoc(), operandType, adaptor.getIn(), constOne);
+ auto constOne = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
+ auto oneAndOperand = emitc::BitwiseAndOp::create(
+ rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
@@ -466,9 +466,8 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
- auto newDivOp =
- rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
- ArrayRef<Value>{lhsAdapted, rhsAdapted});
+ auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
+ ArrayRef<Value>{lhsAdapted, rhsAdapted});
Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
rewriter.replaceOp(uiBinOp, resultAdapted);
return success();
@@ -588,38 +587,40 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
// Add a runtime check for overflow
Value width;
if (emitc::isPointerWideType(type)) {
- Value eight = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType, rewriter.getIndexAttr(8));
- emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
- op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
- width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
- sizeOfCall.getResult(0));
+ Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
+ rewriter.getIndexAttr(8));
+ emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
+ rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
+ width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
+ sizeOfCall.getResult(0));
} else {
- width = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType,
+ width = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
}
- Value excessCheck = rewriter.create<emitc::CmpOp>(
- op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+ Value excessCheck =
+ emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ emitc::CmpPredicate::lt, rhs, width);
// Any concrete value is a valid refinement of poison.
- Value poison = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), arithmeticType,
+ Value poison = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
- emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
- op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+ emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
+ rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
- rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
- Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
- op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
- rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
+ Value resultOrPoison =
+ emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
+ excessCheck, arithmeticResult, poison);
+ emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
Value result = adaptValueType(ternary, rewriter, type);
@@ -700,11 +701,12 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
/*isSigned=*/false);
}
- Value result = rewriter.create<emitc::CastOp>(
- castOp.getLoc(), actualResultType, adaptor.getOperands());
+ Value result = emitc::CastOp::create(
+ rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
if (isa<arith::FPToUIOp>(castOp)) {
- result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+ result =
+ emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
}
rewriter.replaceOp(castOp, result);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f7bf581adc9e3..18e857c81af8d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename OpTy::Adaptor adaptor(operands);
if (targetBits < sourceBits) {
- return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
}
- return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
},
rewriter);
}
@@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
Type newOverflowType = typeConverter->convertType(overflowResultType);
Type structType =
LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
- Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
- loc, structType, adaptor.getLhs(), adaptor.getRhs());
+ Value addOverflow = LLVM::UAddWithOverflowOp::create(
+ rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
Value sumExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
Value overflowExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
return success();
}
@@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
"LLVM dialect should support all signless integer types");
using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
- Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
- Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
- Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
+ Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
+ Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
+ Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
// Split the 2*N-bit wide result into two N-bit values.
- Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
- Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
- Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
- Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
+ Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
+ Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
+ Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
+ Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
rewriter.replaceOp(op, {low, high});
return success();
@@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::ICmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::ICmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs());
},
@@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::FCmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::FCmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs(), fmf);
},
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..d43e6816641cb 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -117,12 +117,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value,
if (auto vectorType = dyn_cast<VectorType>(type)) {
Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
auto attr = SplatElementsAttr::get(vectorType, element);
- return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+ return spirv::ConstantOp::create(builder, loc, vectorType, attr);
}
if (auto intType = dyn_cast<IntegerType>(type))
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getIntegerAttr(type, value));
return nullptr;
}
@@ -418,18 +418,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Type type = lhs.getType();
// Calculate the remainder with spirv.UMod.
- Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
- Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+ Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
+ Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
+ Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
// Fix the sign.
Value isPositive;
if (lhs == signOperand)
- isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+ isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
else
- isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
- Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
- return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+ isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
+ Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
+ return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
+ absNegate);
}
/// Converts arith.remsi to GLSL SPIR-V ops.
@@ -601,13 +602,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
Value allOnes;
if (auto intTy = dyn_cast<IntegerType>(dstType)) {
unsigned componentBitwidth = intTy.getWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, intTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, intTy,
rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
} else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, vectorTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, vectorTy,
SplatElementsAttr::get(vectorTy,
APInt::getAllOnes(componentBitwidth)));
} else {
@@ -653,8 +654,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
// First shift left to sequeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
// bitwidth.
- auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
- op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+ auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
+ rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
// Then we perform arithmetic right shift to make sure we have the right
// sign bits for negative values.
@@ -757,9 +758,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
auto srcType = adaptor.getOperands().front().getType();
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
- Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
- loc, srcType, adaptor.getOperands()[0], mask);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+ Value maskedSrc = spirv::BitwiseAndOp::create(
+ rewriter, loc, srcType, adaptor.getOperands()[0], mask);
+ Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -914,9 +915,9 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
if (auto vectorType = dyn_cast<VectorType>(dstType))
type = VectorType::get(vectorType.getShape(), type);
Value extLhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
Value extRhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
extRhs);
@@ -1067,12 +1068,12 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
}
} else {
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
if (op.getPredicate() == arith::CmpFPredicate::ORD)
- replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+ replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
}
rewriter.replaceOp(op, replace);
@@ -1094,17 +1095,17 @@ class AddUIExtendedOpPattern final
ConversionPatternRewriter &rewriter) const override {
Type dstElemTy = adaptor.getLhs().getType();
Location loc = op->getLoc();
- Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
+ Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
+ adaptor.getRhs());
- Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(0));
- Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(1));
+ Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
// Convert the carry value to boolean.
Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
- Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+ Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
rewriter.replaceOp(op, {sumResult, carryResult});
return success();
@@ -1125,12 +1126,12 @@ class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value result =
- rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+ SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
- Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(0));
- Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(1));
+ Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
rewriter.replaceOp(op, {low, high});
return success();
@@ -1183,20 +1184,20 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getLhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getRhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getLhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getRhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1237,7 +1238,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (!shouldInsertNanGuards<SPIRVOp>() ||
bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
@@ -1245,13 +1246,13 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getRhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getLhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
rewriter.replaceOp(op, select2);
return success();
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..1510b0b16b07d 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -41,11 +41,11 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
Value c2d = op.getC();
Location loc = op.getLoc();
Value b1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d);
Value c1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
- Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
- b1d, c1d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d);
+ Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(),
+ op.getA(), b1d, c1d);
rewriter.replaceOp(op, {newOp});
return success();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444e31821..9bc3fa3473398 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
break;
}
}
@@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
}
llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
@@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
// Move to the first operation in the function.
rewriter.setInsertionPointToStart(&func.getBlocks().front());
// Create an alloca matching the tile size of the `tileOp`.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto tileElementType = tileOp.getTileType().getElementType();
auto memrefType = MemRefType::get(
{ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
auto minElementsOp =
- rewriter.create<arith::ConstantIndexOp>(loc, minElements);
- auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
- auto alloca = rewriter.create<memref::AllocaOp>(
- loc, memrefType, ValueRange{vectorLen, vectorLen});
+ arith::ConstantIndexOp::create(rewriter, loc, minElements);
+ auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
+ ValueRange{vectorLen, vectorLen});
return alloca;
}
@@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
Value tileMemory, Value sliceIndex) const {
auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
auto descriptor =
- rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
- auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
- auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), sliceIndex);
+ UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
+ auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
+ auto sliceIndexI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
static_cast<ConversionPatternRewriter &>(rewriter), loc,
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
@@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId, Value sliceIndex) const {
// Cast the slice index to an i32.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create an all-true predicate for the slice.
auto predicateType = sliceType.clone(rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
- auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
+ auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the value of the current slice from ZA.
- auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+ auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
+ rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
+ sliceIndexI32);
// Load the new tile slice back from memory into ZA.
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Store the current tile slice to memory.
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
- ValueRange{sliceIndex, zero});
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
+ ValueRange{sliceIndex, zero});
}
/// Emits a full in-place swap of the contents of a tile in ZA and a
@@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
RewriterBase::InsertionGuard guard(rewriter);
// Create an scf.for over all tile slices.
auto minNumElts =
- rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(
- loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound =
+ arith::MulIOp::create(rewriter, loc, minNumElts,
+ vector::VectorScaleOp::create(rewriter, loc));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
// Emit a swap for each tile slice.
rewriter.setInsertionPointToStart(forOp.getBody());
auto sliceIndex = forOp.getInductionVar();
@@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
//
// This holds for all tile sizes.
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- loc, rewriter.getI32IntegerAttr(zeroMask));
+ arm_sme::aarch64_sme_zero::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
// Note: Place the `get_tile` op at the start of the block. This ensures
@@ -513,8 +516,8 @@ struct LoadTileSliceConversion
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
@@ -559,8 +562,8 @@ struct StoreTileSliceConversion
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
auto maskOp = storeTileSliceOp.getMask();
@@ -595,28 +598,28 @@ struct InsertTileSliceConversion
auto tileSlice = insertTileSliceOp.getTileSliceIndex();
// Cast tile slice from index to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI1Type(),
+ auto one = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+ auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
- rewriter.create<arm_sme::aarch64_sme_write_vert>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
}
@@ -646,16 +649,16 @@ struct ExtractTileSliceConversion
// Create an 'all true' predicate for the tile slice.
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Zero destination/fallback for tile slice extraction.
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, sliceType, rewriter.getZeroAttr(sliceType));
+ auto zeroVector = arith::ConstantOp::create(
+ rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
// Cast tile slice from index to i32 for intrinsic.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (extractTileSlice.getLayout()) {
@@ -743,7 +746,7 @@ struct OuterProductOpConversion
Value acc = outerProductOp.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
zero.setTileId(tileId);
acc = zero;
}
@@ -754,16 +757,16 @@ struct OuterProductOpConversion
if (!lhsMask || !rhsMask) {
auto predTy =
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
// Create 'arm_sme.intr.mopa' outer product intrinsic.
- rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
- outerProductOp.getLhs(),
- outerProductOp.getRhs());
+ arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ outerProductOp.getLhs(),
+ outerProductOp.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -792,7 +795,7 @@ struct OuterProductWideningOpConversion
Value acc = op.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
zero.setTileId(tileId);
acc = zero;
}
@@ -801,14 +804,14 @@ struct OuterProductWideningOpConversion
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(
- loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
+ OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -843,13 +846,13 @@ struct StreamingVLOpConversion
auto *intrOp = [&]() -> Operation * {
switch (streamingVlOp.getTypeSize()) {
case arm_sme::TypeSize::Byte:
- return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Half:
- return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Word:
- return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Double:
- return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
}
llvm_unreachable("unknown type size in StreamingVLOpConversion");
}();
@@ -872,8 +875,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) {
if (zeroOpsToMerge.size() <= 1)
return;
IRRewriter rewriter(zeroOpsToMerge.front());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- zeroOpsToMerge.front().getLoc(),
+ arm_sme::aarch64_sme_zero::create(
+ rewriter, zeroOpsToMerge.front().getLoc(),
rewriter.getI32IntegerAttr(mergedZeroMask));
for (auto zeroOp : zeroOpsToMerge)
rewriter.eraseOp(zeroOp);
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c29c6ac..9a37b30c14813 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
auto tileSliceOffset = tileSliceIndex;
auto baseIndexPlusTileSliceOffset =
- rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+ arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
outIndices.push_back(indices[1]);
@@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
if (memrefIndices.size() != 2)
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
@@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// elements in a vector of SVL bits for a given element type (SVL_B,
// SVL_H, ..., SVL_Q).
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
Value predicate;
Value upperBound;
@@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// The upper bound of the loop must be clamped at `numTileSlices` as
// `vector.create_mask` allows operands to be greater than the size of a
// dimension.
- auto numRowI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), maskDim0);
- auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), numTileSlices);
+ auto numRowI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), maskDim0);
+ auto numTileSlicesI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), numTileSlices);
auto upperBoundI64 =
- rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
- upperBound = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), upperBoundI64);
+ arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
+ upperBound = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), upperBoundI64);
predicate =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
} else {
upperBound = numTileSlices;
// No mask. Create an 'all true' predicate for the tile slice.
- predicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ predicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
}
bool hasCarriedArgs = bool(initTile);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- hasCarriedArgs ? ValueRange{initTile}
- : ValueRange{});
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
rewriter.setInsertionPointToStart(forOp.getBody());
Value tileSliceIndex = forOp.getInductionVar();
@@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
assert(bool(nextTile) == hasCarriedArgs);
if (nextTile)
- rewriter.create<scf::YieldOp>(loc, nextTile);
+ scf::YieldOp::create(rewriter, loc, nextTile);
return forOp;
}
@@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+ initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
} else {
- initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ return arm_sme::LoadTileSliceOp::create(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];
- auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), numCols);
+ auto numColsI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), numCols);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
+ auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
- auto rowIsActive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
- auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), rowIsActive);
- auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
- auto maskIndex =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
+ auto rowIsActive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+ auto rowIsActiveI32 = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), rowIsActive);
+ auto mask =
+ arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
+ auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), mask);
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
- loc, predicateType, maskIndex.getResult());
+ auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
+ maskIndex.getResult());
auto memrefIndices = getMemrefIndices(
tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
@@ -324,17 +327,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+ auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
- auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
- loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
+ tileLoadOp.getBase(),
+ memrefIndices, maskOp1D,
+ /*passthru=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
- auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
- tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
+ auto insertSlice = arm_sme::InsertTileSliceOp::create(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 94f7caa315cf7..79e1683b4e2cf 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -203,7 +203,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto addFuncDecl = [&](StringRef name, FunctionType type) {
if (module.lookupSymbol(name))
return;
- builder.create<func::FuncOp>(name, type).setPrivate();
+ func::FuncOp::create(builder, name, type).setPrivate();
};
MLIRContext *ctx = module.getContext();
@@ -254,15 +254,15 @@ static void addResumeFunction(ModuleOp module) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
Type ptrType = AsyncAPI::opaquePointerType(ctx);
- auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
+ auto resumeOp = LLVM::LLVMFuncOp::create(
+ moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock(moduleBuilder);
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
- blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
- blockBuilder.create<LLVM::ReturnOp>(ValueRange());
+ LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
+ LLVM::ReturnOp::create(blockBuilder, ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -282,7 +282,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto cast =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return cast.getResult(0);
};
@@ -343,8 +344,8 @@ class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
// Constants for initializing coroutine frame.
auto constZero =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
// Get coroutine id: @llvm.coro.id.
rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
@@ -372,33 +373,33 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Get coroutine frame size: @llvm.coro.size.i64.
Value coroSize =
- rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
+ LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
// Get coroutine frame alignment: @llvm.coro.align.i64.
Value coroAlign =
- rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
+ LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
// Round up the size to be multiple of the alignment. Since aligned_alloc
// requires the size parameter be an integral multiple of the alignment
// parameter.
auto makeConstant = [&](uint64_t c) {
- return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
- rewriter.getI64Type(), c);
+ return LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getI64Type(), c);
};
- coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
+ coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
coroSize =
- rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
+ LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
Value negCoroAlign =
- rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
+ LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
coroSize =
- rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
+ LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
- auto coroAlloc = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
+ auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
+ ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -427,7 +428,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
auto coroMem =
- rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
+ LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
// Free the memory.
auto freeFuncOp =
@@ -455,15 +456,15 @@ class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We are not in the block that is part of the unwind sequence.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
- auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
+ auto constFalse =
+ LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(false));
+ auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
// Mark the end of a coroutine: @llvm.coro.end.
auto coroHdl = adaptor.getHandle();
- rewriter.create<LLVM::CoroEndOp>(
- op->getLoc(), rewriter.getI1Type(),
- ValueRange({coroHdl, constFalse, noneToken}));
+ LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ ValueRange({coroHdl, constFalse, noneToken}));
rewriter.eraseOp(op);
return success();
@@ -534,13 +535,13 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
auto loc = op->getLoc();
// This is not a final suspension point.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+ auto constFalse = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Suspend a coroutine: @llvm.coro.suspend
auto coroState = adaptor.getState();
- auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
- loc, i8, ValueRange({coroState, constFalse}));
+ auto coroSuspend = LLVM::CoroSuspendOp::create(
+ rewriter, loc, i8, ValueRange({coroState, constFalse}));
// Cast return code to i32.
@@ -551,7 +552,7 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
op.getCleanupDest()};
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
- op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
+ op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
/*defaultDestination=*/op.getSuspendDest(),
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
@@ -602,11 +603,11 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
auto gep =
- rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
+ LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
};
rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
@@ -739,8 +740,8 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
- rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
- adaptor.getOperands());
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ adaptor.getOperands());
rewriter.eraseOp(op);
return success();
@@ -772,13 +773,12 @@ class RuntimeAwaitAndResumeOpLowering
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
- rewriter.create<func::CallOp>(
- op->getLoc(), apiFuncName, TypeRange(),
- ValueRange({operand, handle, resumePtr.getRes()}));
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ ValueRange({operand, handle, resumePtr.getRes()}));
rewriter.eraseOp(op);
return success();
@@ -801,9 +801,9 @@ class RuntimeResumeOpLowering
ConversionPatternRewriter &rewriter) const override {
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
// Call async runtime API to execute a coroutine in the managed thread.
auto coroHdl = adaptor.getHandle();
@@ -832,8 +832,8 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getValue().getType();
@@ -845,7 +845,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
Value castedStoragePtr = storagePtr.getResult(0);
// Store the yielded value into the async value storage.
auto value = adaptor.getValue();
- rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
+ LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
// Erase the original runtime store operation.
rewriter.eraseOp(op);
@@ -872,8 +872,8 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getResult().getType();
@@ -960,9 +960,9 @@ class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
LogicalResult
matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto count = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getI64Type(),
- rewriter.getI64IntegerAttr(op.getCount()));
+ auto count =
+ arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
+ rewriter.getI64IntegerAttr(op.getCount()));
auto operand = adaptor.getOperand();
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index b9991f36cdaaf..30a7170cf5c6a 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -47,26 +47,26 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
// Constants
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Dynamically evaluate the size and shape of the unranked memref
- Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+ Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
MemRefType allocType =
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
- Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+ Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
// Create a loop to query dimension sizes, store them as a shape, and
// compute the total size of the memref
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
ValueRange args) {
auto acc = args.front();
- auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+ auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
- rewriter.create<memref::StoreOp>(loc, dim, shape, i);
- acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+ memref::StoreOp::create(rewriter, loc, dim, shape, i);
+ acc = arith::MulIOp::create(rewriter, loc, acc, dim);
- rewriter.create<scf::YieldOp>(loc, acc);
+ scf::YieldOp::create(rewriter, loc, acc);
};
auto size = rewriter
.create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
@@ -78,9 +78,9 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
// Allocate new memref with 1D dynamic shape, then reshape into the
// shape of the original unranked memref
- alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
+ alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
alloc =
- rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
+ memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
} else {
MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
@@ -103,14 +103,15 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
}
// Allocate a memref with identity layout.
- alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+ alloc =
+ memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
// Cast the allocation to the specified type if needed.
if (memrefType != allocType)
alloc =
- rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+ memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
}
- rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
+ memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
rewriter.replaceOp(op, alloc);
return success();
}
More information about the Mlir-commits
mailing list