[Mlir-commits] [mlir] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687) (PR #149879)
Maksim Levental
llvmlistbot at llvm.org
Mon Jul 21 12:57:49 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149879
>From 75ef3dabc70e37441cf8bdea0c317fefc93e2279 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 21 Jul 2025 15:26:00 -0400
Subject: [PATCH] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 278 ++---
.../AffineToStandard/AffineToStandard.cpp | 44 +-
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 213 ++--
.../ArithToArmSME/ArithToArmSME.cpp | 8 +-
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 90 +-
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 38 +-
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 103 +-
.../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 8 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 213 ++--
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 112 +-
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 106 +-
.../BufferizationToMemRef.cpp | 27 +-
.../ComplexCommon/DivisionConverter.cpp | 602 ++++-----
.../ComplexToLLVM/ComplexToLLVM.cpp | 36 +-
.../ComplexToLibm/ComplexToLibm.cpp | 4 +-
.../ComplexToROCDLLibraryCalls.cpp | 4 +-
.../ComplexToStandard/ComplexToStandard.cpp | 706 +++++------
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 10 +-
.../ControlFlowToSCF/ControlFlowToSCF.cpp | 38 +-
.../Conversion/FuncToEmitC/FuncToEmitC.cpp | 4 +-
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 73 +-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 133 +-
.../GPUCommon/GPUToLLVMConversion.cpp | 133 +-
.../GPUCommon/IndexIntrinsicsOpLowering.h | 16 +-
.../GPUCommon/OpToFuncCallLowering.h | 24 +-
.../Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp | 28 +-
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 63 +-
.../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 58 +-
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 76 +-
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 92 +-
.../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 7 +-
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 8 +-
.../Conversion/IndexToLLVM/IndexToLLVM.cpp | 72 +-
.../Conversion/IndexToSPIRV/IndexToSPIRV.cpp | 85 +-
.../Conversion/LLVMCommon/MemRefBuilder.cpp | 118 +-
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 92 +-
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 12 +-
.../Conversion/LLVMCommon/StructBuilder.cpp | 4 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 16 +-
.../Conversion/LLVMCommon/VectorPattern.cpp | 8 +-
.../LinalgToStandard/LinalgToStandard.cpp | 8 +-
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 108 +-
.../Conversion/MathToFuncs/MathToFuncs.cpp | 265 ++--
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 80 +-
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 21 +-
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 112 +-
.../MemRefToEmitC/MemRefToEmitC.cpp | 14 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 220 ++--
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 37 +-
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 303 ++---
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 238 ++--
.../Conversion/OpenACCToSCF/OpenACCToSCF.cpp | 4 +-
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 4 +-
.../PDLToPDLInterp/PDLToPDLInterp.cpp | 217 ++--
.../SCFToControlFlow/SCFToControlFlow.cpp | 58 +-
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 20 +-
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 72 +-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 63 +-
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 65 +-
.../ConvertLaunchFuncToLLVMCalls.cpp | 14 +-
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 190 +--
.../ConvertShapeConstraints.cpp | 2 +-
.../ShapeToStandard/ShapeToStandard.cpp | 249 ++--
.../TensorToSPIRV/TensorToSPIRV.cpp | 10 +-
.../Conversion/TosaToArith/TosaToArith.cpp | 123 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 1073 +++++++++--------
.../TosaToLinalg/TosaToLinalgNamed.cpp | 299 ++---
.../TosaToMLProgram/TosaToMLProgram.cpp | 12 +-
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 38 +-
.../Conversion/TosaToTensor/TosaToTensor.cpp | 35 +-
.../VectorToArmSME/VectorToArmSME.cpp | 91 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 118 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 306 ++---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 235 ++--
.../VectorToSPIRV/VectorToSPIRV.cpp | 91 +-
.../VectorToXeGPU/VectorToXeGPU.cpp | 43 +-
mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 44 +-
77 files changed, 4412 insertions(+), 4232 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..fe3dc91328879 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
- ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
- : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+ ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+ : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type i32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
bool value) {
Type llvmI1 = rewriter.getI1Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+ return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}
/// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
- : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+ : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+ increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+ : increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
- ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
- return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+ return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
- rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
- Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(1 << 14));
- stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
- /*isDisjoint=*/true);
+ LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+ Value swizzleBit = LLVM::ConstantOp::create(
+ rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
} else {
- stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
- rewriter.getI16IntegerAttr(0));
+ stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+ rewriter.getI16IntegerAttr(0));
}
// Get the number of elements.
// Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(0))
+ ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor)
- : Value{};
- Value strides = hasSizes
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor)
- : Value{};
+ Value sizes = hasSizes
+ ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides =
+ hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kStridePosInMemRefDescriptor)
+ : Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
- kOffsetPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAllocatedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAlignedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
if (hasSizes) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, strides, kStridePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+ kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForStore =
- rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+ Value castForStore = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForCmp = rewriter.create<LLVM::BitcastOp>(
- loc, llvmBufferValType, atomicCmpData);
+ Value castForCmp = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
- voffset =
- voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
- : extraOffsetConst;
+ voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+ extraOffsetConst)
+ : extraOffsetConst;
}
- voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+ voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+ sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
- Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
- replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
- replacement);
+ replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+ replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
@@ -465,12 +466,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;
Location loc = op->getLoc();
- rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+ ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
- rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
- rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+ ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+ ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
}
@@ -516,19 +517,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
- return rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
- return rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
- return rewriter.create<LLVM::BitcastOp>(
- loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
}
}
return input;
@@ -549,8 +552,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
- return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
- return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -576,8 +579,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- llvmInput = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+ llvmInput = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -613,7 +616,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
- castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+ castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
@@ -633,8 +636,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- output = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), output);
+ output = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -992,7 +995,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
- lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
@@ -1092,8 +1095,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
- maybeCastBack =
- rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
@@ -1143,22 +1146,22 @@ struct TransposeLoadOpLowering
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
@@ -1316,21 +1319,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
- Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1382,21 +1385,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
- Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1454,54 +1457,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
else
- existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
+ existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
- Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+ Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
- source = rewriter.create<LLVM::ZeroOp>(loc, v2);
- source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
+ source = LLVM::ZeroOp::create(rewriter, loc, v2);
+ source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
}
Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
- sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
- sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+ sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
+ sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
}
Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
@@ -1526,20 +1532,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
- sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1563,17 +1569,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value stoch = adaptor.getStochiasticParam();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1617,14 +1623,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
- rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
- Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
- operand = rewriter.create<LLVM::InsertElementOp>(
- loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
- operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+ Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
+ operand =
+ LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
+ createI32Constant(rewriter, loc, 0));
+ operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
}
return operand;
};
@@ -1711,14 +1718,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
- auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
- loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+ auto dppMovOp =
+ ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
+ rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
- result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
- result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
}
}
@@ -1752,7 +1760,7 @@ struct AMDGPUSwizzleBitModeLowering
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
- rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 3b143ca1ef9ce..3b148f9021666 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
if (predicate == arith::CmpIPredicate::sgt)
- value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+ value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
else
- value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+ value = arith::MinSIOp::create(builder, loc, value, *valueIt);
}
return value;
@@ -154,9 +154,9 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
Value step =
- rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
- auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
- step, op.getInits());
+ arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
+ auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
+ step, op.getInits());
rewriter.eraseBlock(scfForOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
scfForOp.getRegion().end());
@@ -197,7 +197,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
}
steps.reserve(op.getSteps().size());
for (int64_t step : op.getSteps())
- steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
+ steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
// Get the terminator op.
auto affineParOpTerminator =
@@ -205,9 +205,9 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
- parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
- upperBoundTuple, steps,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps,
+ /*bodyBuilderFn=*/nullptr);
rewriter.eraseBlock(parOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
parOp.getRegion().end());
@@ -233,9 +233,9 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
identityVals.push_back(
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
- parOp = rewriter.create<scf::ParallelOp>(
- loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps, identityVals,
+ /*bodyBuilderFn=*/nullptr);
// Copy the body of the affine.parallel op.
rewriter.eraseBlock(parOp.getBody());
@@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
reductionBody.getArgument(1));
- rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
+ scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
return success();
@@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
// Now we just have to handle the condition logic.
auto integerSet = op.getIntegerSet();
- Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::ArrayRef(operands);
@@ -298,18 +298,18 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
auto pred =
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
Value cmpVal =
- rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
- cond = cond
- ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
- : cmpVal;
+ arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
+ cond =
+ cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
+ : cmpVal;
}
cond = cond ? cond
- : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
- /*width=*/1);
+ : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
+ /*width=*/1);
bool hasElseRegion = !op.getElseRegion().empty();
- auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
- hasElseRegion);
+ auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
+ hasElseRegion);
rewriter.inlineRegionBefore(op.getThenRegion(),
&ifOp.getThenRegion().back());
rewriter.eraseBlock(&ifOp.getThenRegion().back());
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679c5039e..73a17b09721b2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -115,9 +115,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+ return arith::TruncFOp::create(rewriter, loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+ return arith::ExtFOp::create(rewriter, loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -139,27 +139,27 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = inVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
Value scalarExt =
- rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
- ArrayRef<int64_t>{});
+ arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+ zerodSplat, ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -171,32 +171,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
inVecType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
- Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, i, elemsThisOp, 1);
+ Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+ elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
if (i + j + 1 < numElements) { // Convert two 8-bit elements
- Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, extResType, inSlice, j / 2);
+ Value asFloats = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, extResType, inSlice, j / 2);
Type desType = VectorType::get(2, outElemType);
Value asType = castF32To(desType, asFloats, loc, rewriter);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, asType, result, i + j, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+ result, i + j, 1);
} else { // Convert a 8-bit element
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
}
}
}
if (inVecType.getRank() != outType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
}
rewriter.replaceOp(op, result);
@@ -208,9 +208,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+ return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+ return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -250,13 +250,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
- Value isNonFinite = rewriter.create<arith::OrIOp>(
- loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+ Value isNonFinite = arith::OrIOp::create(
+ rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+ isNan);
- Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
- Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+ Value clamped =
+ arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
Value res =
- rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
return res;
}
@@ -290,25 +292,25 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType truncResType = VectorType::get(4, outElemType);
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
- Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = outVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
if (outVecType.getShape().empty()) {
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
- rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -320,32 +322,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
- Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+ Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
- thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -373,10 +375,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
// Handle the case where input type is not a vector type
if (!inVectorTy) {
- auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
Value asF16s =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
rewriter.replaceOp(op, result);
return success();
}
@@ -389,7 +391,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
// Handle the vector case. We also handle the (uncommon) case where the vector
@@ -397,25 +399,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
- Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+ Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+ elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
}
thisResult =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
// Place back the truncated result into the possibly larger vector. If we
// are operating on a size 2 vector, these operations should be folded away
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -472,18 +474,18 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast = rewriter.create<vector::BroadcastOp>(
- loc, VectorType::get(1, inType), in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc,
+ VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, inCast, scale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inCast, scale, 0);
scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
return success();
}
@@ -508,20 +510,20 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
int64_t blockSize = computeProduct(ratio);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
@@ -530,23 +532,23 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, slice, uniformScale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, slice, uniformScale, 0);
if (sliceWidth != opWidth)
- scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleExt, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleExt, blockResult, i, 1);
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, sliceWidth, 1);
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -578,21 +580,22 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, inCast, scale, 0,
+ /*existing=*/nullptr);
scaleTrunc =
rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
return success();
@@ -623,13 +626,13 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
@@ -638,26 +641,26 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, slice, uniformScale, 0,
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
/*existing=*/nullptr);
int64_t packedWidth =
cast<VectorType>(scaleTrunc.getType()).getNumElements();
if (packedWidth != opWidth)
- scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleTrunc, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleTrunc, blockResult, i, 1);
+ scaleTrunc = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
index cbe0b3fda3410..ba489436a1a4d 100644
--- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+ auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to write vector to tile
// slice.
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
auto forOp = mlir::arm_sme::createLoopOverTileSlices(
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index a5c08a6378021..59b3fe2e4eaed 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -110,9 +110,9 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
emitc::CmpPredicate predicate;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/false));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/false));
rewriter.replaceOp(op, constant);
return success();
}
@@ -179,9 +179,9 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
return success();
}
case arith::CmpFPredicate::AlwaysTrue: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/true));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/true));
rewriter.replaceOp(op, constant);
return success();
}
@@ -189,8 +189,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
// Compare the values naively
auto cmpResult =
- rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
- adaptor.getLhs(), adaptor.getRhs());
+ emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
+ adaptor.getLhs(), adaptor.getRhs());
// Adjust the results for unordered/ordered semantics
if (unordered) {
@@ -213,16 +213,16 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is NaN exactly when it compares unequal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::ne, operand, operand);
}
/// Return a value that is true if \p operand is not NaN.
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is not NaN exactly when it compares equal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::eq, operand, operand);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -231,8 +231,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Location loc, Value first, Value second) const {
auto firstIsNaN = isNaN(rewriter, loc, first);
auto secondIsNaN = isNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
- firstIsNaN, secondIsNaN);
+ return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNaN, secondIsNaN);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -241,8 +241,8 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
Value first, Value second) const {
auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
- firstIsNotNaN, secondIsNotNaN);
+ return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNotNaN, secondIsNotNaN);
}
};
@@ -378,10 +378,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
Type attrType = (emitc::isPointerWideType(operandType))
? rewriter.getIndexType()
: operandType;
- auto constOne = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), operandType, rewriter.getOneAttr(attrType));
- auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
- op.getLoc(), operandType, adaptor.getIn(), constOne);
+ auto constOne = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
+ auto oneAndOperand = emitc::BitwiseAndOp::create(
+ rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
@@ -466,9 +466,8 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
- auto newDivOp =
- rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
- ArrayRef<Value>{lhsAdapted, rhsAdapted});
+ auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
+ ArrayRef<Value>{lhsAdapted, rhsAdapted});
Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
rewriter.replaceOp(uiBinOp, resultAdapted);
return success();
@@ -588,38 +587,40 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
// Add a runtime check for overflow
Value width;
if (emitc::isPointerWideType(type)) {
- Value eight = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType, rewriter.getIndexAttr(8));
- emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
- op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
- width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
- sizeOfCall.getResult(0));
+ Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
+ rewriter.getIndexAttr(8));
+ emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
+ rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
+ width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
+ sizeOfCall.getResult(0));
} else {
- width = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType,
+ width = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
}
- Value excessCheck = rewriter.create<emitc::CmpOp>(
- op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+ Value excessCheck =
+ emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ emitc::CmpPredicate::lt, rhs, width);
// Any concrete value is a valid refinement of poison.
- Value poison = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), arithmeticType,
+ Value poison = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
- emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
- op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+ emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
+ rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
- rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
- Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
- op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
- rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
+ Value resultOrPoison =
+ emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
+ excessCheck, arithmeticResult, poison);
+ emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
Value result = adaptValueType(ternary, rewriter, type);
@@ -700,11 +701,12 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
/*isSigned=*/false);
}
- Value result = rewriter.create<emitc::CastOp>(
- castOp.getLoc(), actualResultType, adaptor.getOperands());
+ Value result = emitc::CastOp::create(
+ rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
if (isa<arith::FPToUIOp>(castOp)) {
- result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+ result =
+ emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
}
rewriter.replaceOp(castOp, result);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f7bf581adc9e3..18e857c81af8d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename OpTy::Adaptor adaptor(operands);
if (targetBits < sourceBits) {
- return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
}
- return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
},
rewriter);
}
@@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
Type newOverflowType = typeConverter->convertType(overflowResultType);
Type structType =
LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
- Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
- loc, structType, adaptor.getLhs(), adaptor.getRhs());
+ Value addOverflow = LLVM::UAddWithOverflowOp::create(
+ rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
Value sumExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
Value overflowExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
return success();
}
@@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
"LLVM dialect should support all signless integer types");
using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
- Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
- Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
- Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
+ Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
+ Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
+ Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
// Split the 2*N-bit wide result into two N-bit values.
- Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
- Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
- Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
- Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
+ Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
+ Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
+ Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
+ Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
rewriter.replaceOp(op, {low, high});
return success();
@@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::ICmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::ICmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs());
},
@@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::FCmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::FCmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs(), fmf);
},
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..d43e6816641cb 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -117,12 +117,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value,
if (auto vectorType = dyn_cast<VectorType>(type)) {
Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
auto attr = SplatElementsAttr::get(vectorType, element);
- return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+ return spirv::ConstantOp::create(builder, loc, vectorType, attr);
}
if (auto intType = dyn_cast<IntegerType>(type))
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getIntegerAttr(type, value));
return nullptr;
}
@@ -418,18 +418,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Type type = lhs.getType();
// Calculate the remainder with spirv.UMod.
- Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
- Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+ Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
+ Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
+ Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
// Fix the sign.
Value isPositive;
if (lhs == signOperand)
- isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+ isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
else
- isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
- Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
- return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+ isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
+ Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
+ return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
+ absNegate);
}
/// Converts arith.remsi to GLSL SPIR-V ops.
@@ -601,13 +602,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
Value allOnes;
if (auto intTy = dyn_cast<IntegerType>(dstType)) {
unsigned componentBitwidth = intTy.getWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, intTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, intTy,
rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
} else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, vectorTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, vectorTy,
SplatElementsAttr::get(vectorTy,
APInt::getAllOnes(componentBitwidth)));
} else {
@@ -653,8 +654,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
// First shift left to sequeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
// bitwidth.
- auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
- op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+ auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
+ rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
// Then we perform arithmetic right shift to make sure we have the right
// sign bits for negative values.
@@ -757,9 +758,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
auto srcType = adaptor.getOperands().front().getType();
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
- Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
- loc, srcType, adaptor.getOperands()[0], mask);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+ Value maskedSrc = spirv::BitwiseAndOp::create(
+ rewriter, loc, srcType, adaptor.getOperands()[0], mask);
+ Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -914,9 +915,9 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
if (auto vectorType = dyn_cast<VectorType>(dstType))
type = VectorType::get(vectorType.getShape(), type);
Value extLhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
Value extRhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
extRhs);
@@ -1067,12 +1068,12 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
}
} else {
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
if (op.getPredicate() == arith::CmpFPredicate::ORD)
- replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+ replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
}
rewriter.replaceOp(op, replace);
@@ -1094,17 +1095,17 @@ class AddUIExtendedOpPattern final
ConversionPatternRewriter &rewriter) const override {
Type dstElemTy = adaptor.getLhs().getType();
Location loc = op->getLoc();
- Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
+ Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
+ adaptor.getRhs());
- Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(0));
- Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(1));
+ Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
// Convert the carry value to boolean.
Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
- Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+ Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
rewriter.replaceOp(op, {sumResult, carryResult});
return success();
@@ -1125,12 +1126,12 @@ class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value result =
- rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+ SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
- Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(0));
- Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(1));
+ Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
rewriter.replaceOp(op, {low, high});
return success();
@@ -1183,20 +1184,20 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getLhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getRhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getLhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getRhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1237,7 +1238,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (!shouldInsertNanGuards<SPIRVOp>() ||
bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
@@ -1245,13 +1246,13 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getRhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getLhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
rewriter.replaceOp(op, select2);
return success();
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..1510b0b16b07d 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -41,11 +41,11 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
Value c2d = op.getC();
Location loc = op.getLoc();
Value b1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d);
Value c1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
- Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
- b1d, c1d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d);
+ Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(),
+ op.getA(), b1d, c1d);
rewriter.replaceOp(op, {newOp});
return success();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444e31821..9bc3fa3473398 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
break;
}
}
@@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
}
llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
@@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
// Move to the first operation in the function.
rewriter.setInsertionPointToStart(&func.getBlocks().front());
// Create an alloca matching the tile size of the `tileOp`.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto tileElementType = tileOp.getTileType().getElementType();
auto memrefType = MemRefType::get(
{ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
auto minElementsOp =
- rewriter.create<arith::ConstantIndexOp>(loc, minElements);
- auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
- auto alloca = rewriter.create<memref::AllocaOp>(
- loc, memrefType, ValueRange{vectorLen, vectorLen});
+ arith::ConstantIndexOp::create(rewriter, loc, minElements);
+ auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
+ ValueRange{vectorLen, vectorLen});
return alloca;
}
@@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
Value tileMemory, Value sliceIndex) const {
auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
auto descriptor =
- rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
- auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
- auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), sliceIndex);
+ UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
+ auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
+ auto sliceIndexI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
static_cast<ConversionPatternRewriter &>(rewriter), loc,
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
@@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId, Value sliceIndex) const {
// Cast the slice index to an i32.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create an all-true predicate for the slice.
auto predicateType = sliceType.clone(rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
- auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
+ auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the value of the current slice from ZA.
- auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+ auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
+ rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
+ sliceIndexI32);
// Load the new tile slice back from memory into ZA.
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Store the current tile slice to memory.
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
- ValueRange{sliceIndex, zero});
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
+ ValueRange{sliceIndex, zero});
}
/// Emits a full in-place swap of the contents of a tile in ZA and a
@@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
RewriterBase::InsertionGuard guard(rewriter);
// Create an scf.for over all tile slices.
auto minNumElts =
- rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(
- loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound =
+ arith::MulIOp::create(rewriter, loc, minNumElts,
+ vector::VectorScaleOp::create(rewriter, loc));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
// Emit a swap for each tile slice.
rewriter.setInsertionPointToStart(forOp.getBody());
auto sliceIndex = forOp.getInductionVar();
@@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
//
// This holds for all tile sizes.
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- loc, rewriter.getI32IntegerAttr(zeroMask));
+ arm_sme::aarch64_sme_zero::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
// Note: Place the `get_tile` op at the start of the block. This ensures
@@ -513,8 +516,8 @@ struct LoadTileSliceConversion
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
@@ -559,8 +562,8 @@ struct StoreTileSliceConversion
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
auto maskOp = storeTileSliceOp.getMask();
@@ -595,28 +598,28 @@ struct InsertTileSliceConversion
auto tileSlice = insertTileSliceOp.getTileSliceIndex();
// Cast tile slice from index to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI1Type(),
+ auto one = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+ auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
- rewriter.create<arm_sme::aarch64_sme_write_vert>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
}
@@ -646,16 +649,16 @@ struct ExtractTileSliceConversion
// Create an 'all true' predicate for the tile slice.
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Zero destination/fallback for tile slice extraction.
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, sliceType, rewriter.getZeroAttr(sliceType));
+ auto zeroVector = arith::ConstantOp::create(
+ rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
// Cast tile slice from index to i32 for intrinsic.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (extractTileSlice.getLayout()) {
@@ -743,7 +746,7 @@ struct OuterProductOpConversion
Value acc = outerProductOp.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
zero.setTileId(tileId);
acc = zero;
}
@@ -754,16 +757,16 @@ struct OuterProductOpConversion
if (!lhsMask || !rhsMask) {
auto predTy =
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
// Create 'arm_sme.intr.mopa' outer product intrinsic.
- rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
- outerProductOp.getLhs(),
- outerProductOp.getRhs());
+ arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ outerProductOp.getLhs(),
+ outerProductOp.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -792,7 +795,7 @@ struct OuterProductWideningOpConversion
Value acc = op.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
zero.setTileId(tileId);
acc = zero;
}
@@ -801,14 +804,14 @@ struct OuterProductWideningOpConversion
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(
- loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
+ OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -843,13 +846,13 @@ struct StreamingVLOpConversion
auto *intrOp = [&]() -> Operation * {
switch (streamingVlOp.getTypeSize()) {
case arm_sme::TypeSize::Byte:
- return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Half:
- return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Word:
- return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Double:
- return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
}
llvm_unreachable("unknown type size in StreamingVLOpConversion");
}();
@@ -872,8 +875,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) {
if (zeroOpsToMerge.size() <= 1)
return;
IRRewriter rewriter(zeroOpsToMerge.front());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- zeroOpsToMerge.front().getLoc(),
+ arm_sme::aarch64_sme_zero::create(
+ rewriter, zeroOpsToMerge.front().getLoc(),
rewriter.getI32IntegerAttr(mergedZeroMask));
for (auto zeroOp : zeroOpsToMerge)
rewriter.eraseOp(zeroOp);
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c29c6ac..9a37b30c14813 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
auto tileSliceOffset = tileSliceIndex;
auto baseIndexPlusTileSliceOffset =
- rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+ arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
outIndices.push_back(indices[1]);
@@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
if (memrefIndices.size() != 2)
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
@@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// elements in a vector of SVL bits for a given element type (SVL_B,
// SVL_H, ..., SVL_Q).
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
Value predicate;
Value upperBound;
@@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// The upper bound of the loop must be clamped at `numTileSlices` as
// `vector.create_mask` allows operands to be greater than the size of a
// dimension.
- auto numRowI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), maskDim0);
- auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), numTileSlices);
+ auto numRowI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), maskDim0);
+ auto numTileSlicesI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), numTileSlices);
auto upperBoundI64 =
- rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
- upperBound = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), upperBoundI64);
+ arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
+ upperBound = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), upperBoundI64);
predicate =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
} else {
upperBound = numTileSlices;
// No mask. Create an 'all true' predicate for the tile slice.
- predicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ predicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
}
bool hasCarriedArgs = bool(initTile);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- hasCarriedArgs ? ValueRange{initTile}
- : ValueRange{});
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
rewriter.setInsertionPointToStart(forOp.getBody());
Value tileSliceIndex = forOp.getInductionVar();
@@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
assert(bool(nextTile) == hasCarriedArgs);
if (nextTile)
- rewriter.create<scf::YieldOp>(loc, nextTile);
+ scf::YieldOp::create(rewriter, loc, nextTile);
return forOp;
}
@@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+ initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
} else {
- initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ return arm_sme::LoadTileSliceOp::create(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];
- auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), numCols);
+ auto numColsI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), numCols);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
+ auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
- auto rowIsActive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
- auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), rowIsActive);
- auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
- auto maskIndex =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
+ auto rowIsActive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+ auto rowIsActiveI32 = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), rowIsActive);
+ auto mask =
+ arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
+ auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), mask);
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
- loc, predicateType, maskIndex.getResult());
+ auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
+ maskIndex.getResult());
auto memrefIndices = getMemrefIndices(
tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
@@ -324,17 +327,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+ auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
- auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
- loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
+ tileLoadOp.getBase(),
+ memrefIndices, maskOp1D,
+ /*passthru=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
- auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
- tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
+ auto insertSlice = arm_sme::InsertTileSliceOp::create(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 94f7caa315cf7..79e1683b4e2cf 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -203,7 +203,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto addFuncDecl = [&](StringRef name, FunctionType type) {
if (module.lookupSymbol(name))
return;
- builder.create<func::FuncOp>(name, type).setPrivate();
+ func::FuncOp::create(builder, name, type).setPrivate();
};
MLIRContext *ctx = module.getContext();
@@ -254,15 +254,15 @@ static void addResumeFunction(ModuleOp module) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
Type ptrType = AsyncAPI::opaquePointerType(ctx);
- auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
+ auto resumeOp = LLVM::LLVMFuncOp::create(
+ moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock(moduleBuilder);
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
- blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
- blockBuilder.create<LLVM::ReturnOp>(ValueRange());
+ LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
+ LLVM::ReturnOp::create(blockBuilder, ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -282,7 +282,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto cast =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return cast.getResult(0);
};
@@ -343,8 +344,8 @@ class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
// Constants for initializing coroutine frame.
auto constZero =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
// Get coroutine id: @llvm.coro.id.
rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
@@ -372,33 +373,33 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Get coroutine frame size: @llvm.coro.size.i64.
Value coroSize =
- rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
+ LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
// Get coroutine frame alignment: @llvm.coro.align.i64.
Value coroAlign =
- rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
+ LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
// Round up the size to be multiple of the alignment. Since aligned_alloc
// requires the size parameter be an integral multiple of the alignment
// parameter.
auto makeConstant = [&](uint64_t c) {
- return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
- rewriter.getI64Type(), c);
+ return LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getI64Type(), c);
};
- coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
+ coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
coroSize =
- rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
+ LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
Value negCoroAlign =
- rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
+ LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
coroSize =
- rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
+ LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
- auto coroAlloc = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
+ auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
+ ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -427,7 +428,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
auto coroMem =
- rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
+ LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
// Free the memory.
auto freeFuncOp =
@@ -455,15 +456,15 @@ class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We are not in the block that is part of the unwind sequence.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
- auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
+ auto constFalse =
+ LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(false));
+ auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
// Mark the end of a coroutine: @llvm.coro.end.
auto coroHdl = adaptor.getHandle();
- rewriter.create<LLVM::CoroEndOp>(
- op->getLoc(), rewriter.getI1Type(),
- ValueRange({coroHdl, constFalse, noneToken}));
+ LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ ValueRange({coroHdl, constFalse, noneToken}));
rewriter.eraseOp(op);
return success();
@@ -534,13 +535,13 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
auto loc = op->getLoc();
// This is not a final suspension point.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+ auto constFalse = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Suspend a coroutine: @llvm.coro.suspend
auto coroState = adaptor.getState();
- auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
- loc, i8, ValueRange({coroState, constFalse}));
+ auto coroSuspend = LLVM::CoroSuspendOp::create(
+ rewriter, loc, i8, ValueRange({coroState, constFalse}));
// Cast return code to i32.
@@ -551,7 +552,7 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
op.getCleanupDest()};
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
- op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
+ op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
/*defaultDestination=*/op.getSuspendDest(),
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
@@ -602,11 +603,11 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
auto gep =
- rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
+ LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
};
rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
@@ -739,8 +740,8 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
- rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
- adaptor.getOperands());
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ adaptor.getOperands());
rewriter.eraseOp(op);
return success();
@@ -772,13 +773,12 @@ class RuntimeAwaitAndResumeOpLowering
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
- rewriter.create<func::CallOp>(
- op->getLoc(), apiFuncName, TypeRange(),
- ValueRange({operand, handle, resumePtr.getRes()}));
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ ValueRange({operand, handle, resumePtr.getRes()}));
rewriter.eraseOp(op);
return success();
@@ -801,9 +801,9 @@ class RuntimeResumeOpLowering
ConversionPatternRewriter &rewriter) const override {
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
// Call async runtime API to execute a coroutine in the managed thread.
auto coroHdl = adaptor.getHandle();
@@ -832,8 +832,8 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getValue().getType();
@@ -845,7 +845,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
Value castedStoragePtr = storagePtr.getResult(0);
// Store the yielded value into the async value storage.
auto value = adaptor.getValue();
- rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
+ LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
// Erase the original runtime store operation.
rewriter.eraseOp(op);
@@ -872,8 +872,8 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getResult().getType();
@@ -960,9 +960,9 @@ class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
LogicalResult
matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto count = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getI64Type(),
- rewriter.getI64IntegerAttr(op.getCount()));
+ auto count =
+ arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
+ rewriter.getI64IntegerAttr(op.getCount()));
auto operand = adaptor.getOperand();
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index b9991f36cdaaf..30a7170cf5c6a 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -47,26 +47,26 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
// Constants
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Dynamically evaluate the size and shape of the unranked memref
- Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+ Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
MemRefType allocType =
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
- Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+ Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
// Create a loop to query dimension sizes, store them as a shape, and
// compute the total size of the memref
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
ValueRange args) {
auto acc = args.front();
- auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+ auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
- rewriter.create<memref::StoreOp>(loc, dim, shape, i);
- acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+ memref::StoreOp::create(rewriter, loc, dim, shape, i);
+ acc = arith::MulIOp::create(rewriter, loc, acc, dim);
- rewriter.create<scf::YieldOp>(loc, acc);
+ scf::YieldOp::create(rewriter, loc, acc);
};
auto size = rewriter
.create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
@@ -78,9 +78,9 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
// Allocate new memref with 1D dynamic shape, then reshape into the
// shape of the original unranked memref
- alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
+ alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
alloc =
- rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
+ memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
} else {
MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
@@ -103,14 +103,15 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
}
// Allocate a memref with identity layout.
- alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+ alloc =
+ memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
// Cast the allocation to the specified type if needed.
if (memrefType != allocType)
alloc =
- rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+ memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
}
- rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
+ memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
rewriter.replaceOp(op, alloc);
return success();
}
diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
index 70b22386f1eea..14fbb9bf09545 100644
--- a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
+++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
@@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
- Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
+ Value rhsSqNorm = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
- Value realNumerator = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
+ Value realNumerator = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
- Value imagNumerator = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value imagNumerator = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
- *resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf);
- *resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf);
+ *resultRe =
+ LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
+ *resultIm =
+ LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
void mlir::complex::convertDivToStandardUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
- Value rhsSqNorm = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf);
+ Value rhsSqNorm = arith::AddFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
- Value realNumerator = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf);
- Value imagNumerator = rewriter.create<arith::SubFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value realNumerator = arith::AddFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
+ Value imagNumerator = arith::SubFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
*resultRe =
- rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf);
+ arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
*resultIm =
- rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf);
+ arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
// Smith's algorithm to divide complex numbers. It is just a bit smarter
@@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
- rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf);
- Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>(
- loc, rhsIm,
- rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
- Value realNumerator1 = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf),
- lhsIm, fmf);
- Value resultReal1 =
- rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf);
- Value imagNumerator1 = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf),
- lhsRe, fmf);
- Value resultImag1 =
- rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf);
+ LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
+ Value rhsRealImagDenom = LLVM::FAddOp::create(
+ rewriter, loc, rhsIm,
+ LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
+ Value realNumerator1 = LLVM::FAddOp::create(
+ rewriter, loc,
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
+ fmf);
+ Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1,
+ rhsRealImagDenom, fmf);
+ Value imagNumerator1 = LLVM::FSubOp::create(
+ rewriter, loc,
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
+ fmf);
+ Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1,
+ rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
- rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf);
- Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>(
- loc, rhsRe,
- rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
- Value realNumerator2 = rewriter.create<LLVM::FAddOp>(
- loc, lhsRe,
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
- Value resultReal2 =
- rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf);
- Value imagNumerator2 = rewriter.create<LLVM::FSubOp>(
- loc, lhsIm,
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
- Value resultImag2 =
- rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf);
+ LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
+ Value rhsImagRealDenom = LLVM::FAddOp::create(
+ rewriter, loc, rhsRe,
+ LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
+ Value realNumerator2 = LLVM::FAddOp::create(
+ rewriter, loc, lhsRe,
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
+ Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2,
+ rhsImagRealDenom, fmf);
+ Value imagNumerator2 = LLVM::FSubOp::create(
+ rewriter, loc, lhsIm,
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
+ Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2,
+ rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- loc, elementType, rewriter.getZeroAttr(elementType));
- Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf);
- Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
- Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf);
- Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
- Value lhsRealIsNotNaN =
- rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
- Value lhsImagIsNotNaN =
- rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getZeroAttr(elementType));
+ Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf);
+ Value rhsRealIsZero = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
+ Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf);
+ Value rhsImagIsZero = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
+ Value lhsRealIsNotNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
+ Value lhsImagIsNotNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
Value lhsContainsNotNaNValue =
- rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
- Value resultIsInfinity = rewriter.create<LLVM::AndOp>(
- loc, lhsContainsNotNaNValue,
- rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
- Value inf = rewriter.create<LLVM::ConstantOp>(
- loc, elementType,
+ LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
+ Value resultIsInfinity = LLVM::AndOp::create(
+ rewriter, loc, lhsContainsNotNaNValue,
+ LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
+ Value inf = LLVM::ConstantOp::create(
+ rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfrhsReal =
- rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe);
+ LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
- rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf);
+ LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf);
Value infinityResultImag =
- rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf);
+ LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
- Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
- Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
+ Value rhsRealFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
+ Value rhsImagFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
Value rhsFinite =
- rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite);
- Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf);
- Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
- Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf);
- Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
+ LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
+ Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf);
+ Value lhsRealInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
+ Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf);
+ Value lhsImagInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
Value lhsInfinite =
- rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
+ LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
- rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite);
- Value one = rewriter.create<LLVM::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 1));
- Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero),
- lhsRe);
- Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero),
- lhsIm);
+ LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getFloatAttr(elementType, 1));
+ Value lhsRealIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe);
+ Value lhsImagIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm);
Value lhsRealIsInfWithSignTimesrhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesrhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
- Value resultReal3 = rewriter.create<LLVM::FMulOp>(
- loc, inf,
- rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal,
- lhsImagIsInfWithSignTimesrhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
+ Value resultReal3 = LLVM::FMulOp::create(
+ rewriter, loc, inf,
+ LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal,
+ lhsImagIsInfWithSignTimesrhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesrhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesrhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
- Value resultImag3 = rewriter.create<LLVM::FMulOp>(
- loc, inf,
- rewriter.create<LLVM::FSubOp>(loc, lhsImagIsInfWithSignTimesrhsReal,
- lhsRealIsInfWithSignTimesrhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
+ Value resultImag3 = LLVM::FMulOp::create(
+ rewriter, loc, inf,
+ LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal,
+ lhsRealIsInfWithSignTimesrhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
- Value lhsRealFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
- Value lhsImagFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
+ Value lhsRealFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
+ Value lhsImagFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
Value lhsFinite =
- rewriter.create<LLVM::AndOp>(loc, lhsRealFinite, lhsImagFinite);
- Value rhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
- Value rhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
+ LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
+ Value rhsRealInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
+ Value rhsImagInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
Value rhsInfinite =
- rewriter.create<LLVM::OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
+ LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
- rewriter.create<LLVM::AndOp>(loc, lhsFinite, rhsInfinite);
- Value rhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, rhsRealInfinite, one, zero),
- rhsRe);
- Value rhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, rhsImagInfinite, one, zero),
- rhsIm);
+ LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite);
+ Value rhsRealIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe);
+ Value rhsImagIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm);
Value rhsRealIsInfWithSignTimeslhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
- Value resultReal4 = rewriter.create<LLVM::FMulOp>(
- loc, zero,
- rewriter.create<LLVM::FAddOp>(loc, rhsRealIsInfWithSignTimeslhsReal,
- rhsImagIsInfWithSignTimeslhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
+ Value resultReal4 = LLVM::FMulOp::create(
+ rewriter, loc, zero,
+ LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal,
+ rhsImagIsInfWithSignTimeslhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimeslhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
- Value resultImag4 = rewriter.create<LLVM::FMulOp>(
- loc, zero,
- rewriter.create<LLVM::FSubOp>(loc, rhsRealIsInfWithSignTimeslhsImag,
- rhsImagIsInfWithSignTimeslhsReal, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
+ Value resultImag4 = LLVM::FMulOp::create(
+ rewriter, loc, zero,
+ LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag,
+ rhsImagIsInfWithSignTimeslhsReal, fmf),
fmf);
- Value realAbsSmallerThanImagAbs = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
- Value resultReal5 = rewriter.create<LLVM::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
- Value resultImag5 = rewriter.create<LLVM::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
- Value resultRealSpecialCase3 = rewriter.create<LLVM::SelectOp>(
- loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
- Value resultImagSpecialCase3 = rewriter.create<LLVM::SelectOp>(
- loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
- Value resultRealSpecialCase2 = rewriter.create<LLVM::SelectOp>(
- loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
- Value resultImagSpecialCase2 = rewriter.create<LLVM::SelectOp>(
- loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
- Value resultRealSpecialCase1 = rewriter.create<LLVM::SelectOp>(
- loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
- Value resultImagSpecialCase1 = rewriter.create<LLVM::SelectOp>(
- loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
+ Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
+ Value resultReal5 = LLVM::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
+ Value resultImag5 = LLVM::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
+ Value resultRealSpecialCase3 = LLVM::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
+ Value resultImagSpecialCase3 = LLVM::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
+ Value resultRealSpecialCase2 = LLVM::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
+ Value resultImagSpecialCase2 = LLVM::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
+ Value resultRealSpecialCase1 =
+ LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultReal, resultRealSpecialCase2);
+ Value resultImagSpecialCase1 =
+ LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultImag, resultImagSpecialCase2);
- Value resultRealIsNaN = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
- Value resultImagIsNaN = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
+ Value resultRealIsNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
+ Value resultImagIsNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
Value resultIsNaN =
- rewriter.create<LLVM::AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
+ LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
- *resultRe = rewriter.create<LLVM::SelectOp>(
- loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
- *resultIm = rewriter.create<LLVM::SelectOp>(
- loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
+ *resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultRealSpecialCase1, resultReal5);
+ *resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultImagSpecialCase1, resultImag5);
}
void mlir::complex::convertDivToStandardUsingRangeReduction(
@@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
- rewriter.create<arith::DivFOp>(loc, rhsRe, rhsIm, fmf);
- Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
- loc, rhsIm,
- rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
- Value realNumerator1 = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealImagRatio, fmf),
- lhsIm, fmf);
- Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
- rhsRealImagDenom, fmf);
- Value imagNumerator1 = rewriter.create<arith::SubFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealImagRatio, fmf),
- lhsRe, fmf);
- Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
- rhsRealImagDenom, fmf);
+ arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
+ Value rhsRealImagDenom = arith::AddFOp::create(
+ rewriter, loc, rhsIm,
+ arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
+ Value realNumerator1 = arith::AddFOp::create(
+ rewriter, loc,
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
+ fmf);
+ Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1,
+ rhsRealImagDenom, fmf);
+ Value imagNumerator1 = arith::SubFOp::create(
+ rewriter, loc,
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
+ fmf);
+ Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1,
+ rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
- rewriter.create<arith::DivFOp>(loc, rhsIm, rhsRe, fmf);
- Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
- loc, rhsRe,
- rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
- Value realNumerator2 = rewriter.create<arith::AddFOp>(
- loc, lhsRe,
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
- Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
- rhsImagRealDenom, fmf);
- Value imagNumerator2 = rewriter.create<arith::SubFOp>(
- loc, lhsIm,
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
- Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
- rhsImagRealDenom, fmf);
+ arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
+ Value rhsImagRealDenom = arith::AddFOp::create(
+ rewriter, loc, rhsRe,
+ arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
+ Value realNumerator2 = arith::AddFOp::create(
+ rewriter, loc, lhsRe,
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
+ Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2,
+ rhsImagRealDenom, fmf);
+ Value imagNumerator2 = arith::SubFOp::create(
+ rewriter, loc, lhsIm,
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
+ Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2,
+ rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getZeroAttr(elementType));
- Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsRe, fmf);
- Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
- Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsIm, fmf);
- Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
- Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ORD, lhsRe, zero);
- Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ORD, lhsIm, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getZeroAttr(elementType));
+ Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf);
+ Value rhsRealIsZero = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
+ Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf);
+ Value rhsImagIsZero = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
+ Value lhsRealIsNotNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero);
+ Value lhsImagIsNotNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero);
Value lhsContainsNotNaNValue =
- rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
- Value resultIsInfinity = rewriter.create<arith::AndIOp>(
- loc, lhsContainsNotNaNValue,
- rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
- Value inf = rewriter.create<arith::ConstantOp>(
- loc, elementType,
+ arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
+ Value resultIsInfinity = arith::AndIOp::create(
+ rewriter, loc, lhsContainsNotNaNValue,
+ arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
+ Value inf = arith::ConstantOp::create(
+ rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfRhsReal =
- rewriter.create<math::CopySignOp>(loc, inf, rhsRe);
+ math::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
- rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsRe, fmf);
+ arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf);
Value infinityResultImag =
- rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsIm, fmf);
+ arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
- Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
- Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
+ Value rhsRealFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
+ Value rhsImagFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
Value rhsFinite =
- rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
- Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsRe, fmf);
- Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
- Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsIm, fmf);
- Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
+ arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
+ Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf);
+ Value lhsRealInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
+ Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf);
+ Value lhsImagInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
Value lhsInfinite =
- rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
+ arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
- rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
- Value one = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 1));
- Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
+ arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite);
+ Value one = arith::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getFloatAttr(elementType, 1));
+ Value lhsRealIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero),
lhsRe);
- Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
+ Value lhsImagIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero),
lhsIm);
Value lhsRealIsInfWithSignTimesRhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesRhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
- Value resultReal3 = rewriter.create<arith::MulFOp>(
- loc, inf,
- rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
- lhsImagIsInfWithSignTimesRhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
+ Value resultReal3 = arith::MulFOp::create(
+ rewriter, loc, inf,
+ arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal,
+ lhsImagIsInfWithSignTimesRhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesRhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesRhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
- Value resultImag3 = rewriter.create<arith::MulFOp>(
- loc, inf,
- rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
- lhsRealIsInfWithSignTimesRhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
+ Value resultImag3 = arith::MulFOp::create(
+ rewriter, loc, inf,
+ arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal,
+ lhsRealIsInfWithSignTimesRhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
- Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
- Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
+ Value lhsRealFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
+ Value lhsImagFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
Value lhsFinite =
- rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
- Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
- Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
+ arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
+ Value rhsRealInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
+ Value rhsImagInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
Value rhsInfinite =
- rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
+ arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
- rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
- Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
+ arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite);
+ Value rhsRealIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero),
rhsRe);
- Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
+ Value rhsImagIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero),
rhsIm);
Value rhsRealIsInfWithSignTimesLhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
- Value resultReal4 = rewriter.create<arith::MulFOp>(
- loc, zero,
- rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
- rhsImagIsInfWithSignTimesLhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
+ Value resultReal4 = arith::MulFOp::create(
+ rewriter, loc, zero,
+ arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal,
+ rhsImagIsInfWithSignTimesLhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimesLhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
- Value resultImag4 = rewriter.create<arith::MulFOp>(
- loc, zero,
- rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
- rhsImagIsInfWithSignTimesLhsReal, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
+ Value resultImag4 = arith::MulFOp::create(
+ rewriter, loc, zero,
+ arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag,
+ rhsImagIsInfWithSignTimesLhsReal, fmf),
fmf);
- Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
- Value resultReal5 = rewriter.create<arith::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
- Value resultImag5 = rewriter.create<arith::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
- Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
- loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
- Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
- loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
- Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
- loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
- Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
- loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
- Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
- loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
- Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
- loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
+ Value realAbsSmallerThanImagAbs = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
+ Value resultReal5 = arith::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
+ Value resultImag5 = arith::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
+ Value resultRealSpecialCase3 = arith::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
+ Value resultImagSpecialCase3 = arith::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
+ Value resultRealSpecialCase2 = arith::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
+ Value resultImagSpecialCase2 = arith::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
+ Value resultRealSpecialCase1 =
+ arith::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultReal, resultRealSpecialCase2);
+ Value resultImagSpecialCase1 =
+ arith::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultImag, resultImagSpecialCase2);
- Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, resultReal5, zero);
- Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, resultImag5, zero);
+ Value resultRealIsNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero);
+ Value resultImagIsNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero);
Value resultIsNaN =
- rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
+ arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
- *resultRe = rewriter.create<arith::SelectOp>(
- loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
- *resultIm = rewriter.create<arith::SelectOp>(
- loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
+ *resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultRealSpecialCase1, resultReal5);
+ *resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultImagSpecialCase1, resultImag5);
}
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index e5e862315941d..86d02e6c6209f 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
Location loc, Type type) {
- Value val = builder.create<LLVM::PoisonOp>(loc, type);
+ Value val = LLVM::PoisonOp::create(builder, loc, type);
return ComplexStructBuilder(val);
}
@@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value sqNorm = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
- rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
+ Value sqNorm = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
+ LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
return success();
@@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value real =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
- Value imag =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
+ Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
+ arg.rhs.real(), fmf);
+ Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
+ arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
Value lhsRe = arg.lhs.real();
Value lhsIm = arg.lhs.imag();
- Value real = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
+ Value real = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
- Value imag = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value imag = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value real =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
- Value imag =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
+ Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
+ arg.rhs.real(), fmf);
+ Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
+ arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
index 56269d189873a..f83cac751ff05 100644
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
- opFunctionTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
+ opFunctionTy);
opFunc.setPrivate();
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 99d5424aef79a..6f0fc2965e6fd 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
auto funcTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
- funcTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(),
+ funcName, funcTy);
opFunc.setPrivate();
}
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0c832c452718b..eeff8a93e7a72 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -31,44 +31,45 @@ enum class AbsFn { abs, sqrt, rsqrt };
// Returns the absolute value, its square root or its reciprocal square root.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
- Value one = b.create<arith::ConstantOp>(real.getType(),
- b.getFloatAttr(real.getType(), 1.0));
+ Value one = arith::ConstantOp::create(b, real.getType(),
+ b.getFloatAttr(real.getType(), 1.0));
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
- Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+ Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
+ Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
// The lowering below requires NaNs and infinities to work correctly.
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
- Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
- Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
+ Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
+ Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
+ Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
Value result;
if (fn == AbsFn::rsqrt) {
- ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
- max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
+ ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
+ max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
}
if (fn == AbsFn::sqrt) {
- Value quarter = b.create<arith::ConstantOp>(
- real.getType(), b.getFloatAttr(real.getType(), 0.25));
+ Value quarter = arith::ConstantOp::create(
+ b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
- Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
- Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
+ Value p025 =
+ math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
} else {
- Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
}
- Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
- result, fmfWithNaNInf);
- return b.create<arith::SelectOp>(isNaN, min, result);
+ Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
+ result, fmfWithNaNInf);
+ return arith::SelectOp::create(b, isNaN, min, result);
}
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
@@ -81,8 +82,8 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
return success();
@@ -105,28 +106,28 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
- Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
- Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
+ Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
+ Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
Value rhsSquaredPlusLhsSquared =
- b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
+ complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
Value sqrtOfRhsSquaredPlusLhsSquared =
- b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
+ complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value i = b.create<complex::CreateOp>(type, zero, one);
- Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
- Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value i = complex::CreateOp::create(b, type, zero, one);
+ Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
+ Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
- Value divResult = b.create<complex::DivOp>(
- rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
- Value logResult = b.create<complex::LogOp>(divResult, fmf);
+ Value divResult = complex::DivOp::create(
+ b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
+ Value logResult = complex::LogOp::create(b, divResult, fmf);
- Value negativeOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
- Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
+ Value negativeOne = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -1));
+ Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
return success();
@@ -146,14 +147,18 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
- Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
- Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
- Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
- Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
+ Value realLhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value imagLhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value realRhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
+ Value imagRhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
Value realComparison =
- rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
+ arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
Value imagComparison =
- rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
+ arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
imagComparison);
@@ -176,14 +181,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
- fmf.getValue());
- Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
- Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
- fmf.getValue());
+ Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
+ realRhs, fmf.getValue());
+ Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
+ Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
+ imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -205,20 +210,20 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
// Trigonometric ops use a set of common building blocks to convert to real
// ops. Here we create these building blocks and call into an op-specific
// implementation in the subclass to combine them.
- Value half = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
- Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
- Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
- Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
- Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
- Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
+ Value half = arith::ConstantOp::create(
+ rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
+ Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
+ Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
+ Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
+ Value sin = math::SinOp::create(rewriter, loc, real, fmf);
+ Value cos = math::CosOp::create(rewriter, loc, real, fmf);
auto resultPair =
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
@@ -251,11 +256,11 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
Value sum =
- rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+ arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
Value diff =
- rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
+ arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
return {resultReal, resultImag};
}
};
@@ -275,13 +280,13 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value lhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value lhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value rhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value rhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value resultReal, resultImag;
@@ -318,16 +323,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
- Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
+ Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
Value resultReal =
- rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
- Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
+ Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
Value resultImag =
- rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -340,11 +345,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
arith::FastMathFlagsAttr fmf) {
auto argType = mlir::cast<FloatType>(arg.getType());
Value poly =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
for (unsigned i = 1; i < coefficients.size(); ++i) {
- poly = b.create<math::FmaOp>(
- poly, arg,
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
+ poly = math::FmaOp::create(
+ b, poly, arg,
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
fmf);
}
return poly;
@@ -365,26 +370,26 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
- Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
+ Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
+ Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
- Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
- Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+ Value expm1Real = math::ExpM1Op::create(b, real, fmf);
+ Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
Value cosm1Imag = emitCosm1(imag, fmf, b);
- Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
+ Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
- Value realResult = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+ Value realResult = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
- Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
- zero, fmf.getValue());
- Value imagResult = b.create<arith::SelectOp>(
- imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+ Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
+ zero, fmf.getValue());
+ Value imagResult = arith::SelectOp::create(
+ b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
imagResult);
@@ -395,8 +400,8 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
ImplicitLocOpBuilder &b) const {
auto argType = mlir::cast<FloatType>(arg.getType());
- auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
- auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+ auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
+ auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
// Algorithm copied from cephes cosm1.
SmallVector<double, 7> kCoeffs{
@@ -405,23 +410,23 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
4.1666666666666666609054E-2,
};
- Value cos = b.create<math::CosOp>(arg, fmf);
- Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+ Value cos = math::CosOp::create(b, arg, fmf);
+ Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
- Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
- Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+ Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
+ Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
auto forSmallArg =
- b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
- b.create<arith::MulFOp>(negHalf, argPow2, fmf));
+ arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
+ arith::MulFOp::create(b, negHalf, argPow2, fmf));
// (pi/4)^2 is approximately 0.61685
Value piOver4Pow2 =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
- Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
- piOver4Pow2, fmf.getValue());
- return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
+ Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
+ piOver4Pow2, fmf.getValue());
+ return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
}
};
@@ -436,13 +441,13 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
- fmf.getValue());
- Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
+ fmf.getValue());
+ Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value resultImag =
- b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
+ math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -460,40 +465,42 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
- Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
+ Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
- Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
+ Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
+ Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
- Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
- realPlusOne, absImag, fmf);
- Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
+ Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
+ realPlusOne, absImag, fmf);
+ Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
Value maxAbsOfRealPlusOneAndImagMinusOne =
- b.create<arith::SelectOp>(useReal, real, maxMinusOne);
+ arith::SelectOp::create(b, useReal, real, maxMinusOne);
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
+ Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
Value logOfMaxAbsOfRealPlusOneAndImag =
- b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
- Value logOfSqrtPart = b.create<math::Log1pOp>(
- b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
+ math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
+ Value logOfSqrtPart = math::Log1pOp::create(
+ b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
fmfWithNaNInf);
- Value r = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
+ Value r = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
- Value resultReal = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
+ Value resultReal = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
+ fmfWithNaNInf),
minAbs, r);
- Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
+ Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -511,22 +518,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
auto elementType = cast<FloatType>(type.getElementType());
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
auto fmfValue = fmf.getValue();
- Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
Value lhsRealTimesRhsReal =
- b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
Value lhsImagTimesRhsImag =
- b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
- Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+ lhsImagTimesRhsImag, fmfValue);
Value lhsImagTimesRhsReal =
- b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
Value lhsRealTimesRhsImag =
- b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
- Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+ Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+ lhsRealTimesRhsImag, fmfValue);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
@@ -543,11 +550,11 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negReal = rewriter.create<arith::NegFOp>(loc, real);
- Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negReal = arith::NegFOp::create(rewriter, loc, real);
+ Value negImag = arith::NegFOp::create(rewriter, loc, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
return success();
}
@@ -570,11 +577,11 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
Value sum =
- rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+ arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
Value diff =
- rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
+ arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
return {resultReal, resultImag};
}
};
@@ -593,64 +600,65 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
- Value cos = b.create<math::CosOp>(sqrtArg, fmf);
- Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
+ Value cos = math::CosOp::create(b, sqrtArg, fmf);
+ Value sin = math::SinOp::create(b, sqrtArg, fmf);
// sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
// 0 * inf.
Value sinIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
- Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
- Value resultImag = b.create<arith::SelectOp>(
- sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+ Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
+ Value resultImag = arith::SelectOp::create(
+ b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
Value inf = cst(APFloat::getInf(floatSemantics));
Value negInf = cst(APFloat::getInf(floatSemantics, true));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+ Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
- Value absImagIsNotInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
+ Value absImagIsNotInf = arith::CmpFOp::create(
+ b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
- Value realIsNegInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
+ Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ real, negInf, fmf);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+ resultReal = arith::SelectOp::create(
+ b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
resultReal);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+ resultReal = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
- Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+ Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
+ resultImag = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
nan, resultImag);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+ resultImag = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
resultImag);
}
Value resultIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
- resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+ resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -669,19 +677,20 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
- Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
- auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
- Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
- Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
- Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
+ Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
+ auto abs =
+ complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
+ Value realSign = arith::DivFOp::create(b, real, abs, fmf);
+ Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
+ Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
adaptor.getComplex(), sign);
return success();
@@ -703,84 +712,84 @@ struct TanTanhOpConversion : public OpConversionPattern<Op> {
const auto &floatSemantics = elementType.getFloatSemantics();
Value real =
- b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
Value imag =
- b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1.0));
+ complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1.0));
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(real, imag);
- real = b.create<arith::MulFOp>(real, negOne, fmf);
+ real = arith::MulFOp::create(b, real, negOne, fmf);
}
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
Value inf = cst(APFloat::getInf(floatSemantics));
- Value four = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 4.0));
- Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
- Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
-
- Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
- Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
- Value realNum =
- b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
-
- Value cosImag = b.create<math::CosOp>(imag, fmf);
- Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
- Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
-
- Value imagNum = b.create<arith::MulFOp>(
- four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
-
- Value expSumMinusTwo =
- b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
+ Value four = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 4.0));
+ Value twoReal = arith::AddFOp::create(b, real, real, fmf);
+ Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
+
+ Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
+ Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
+ Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
+
+ Value cosImag = math::CosOp::create(b, imag, fmf);
+ Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
+ Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
+
+ Value imagNum = arith::MulFOp::create(
+ b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
+
+ Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
Value denom =
- b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
+ arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
- Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- expSumMinusTwo, inf, fmf);
- Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
+ Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ expSumMinusTwo, inf, fmf);
+ Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
- Value resultReal = b.create<arith::SelectOp>(
- isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
- Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
+ Value resultReal = arith::SelectOp::create(
+ b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
+ Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value zero = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, 0.0));
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value zero = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.0));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absRealIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
+ Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value absRealIsNotInf = b.create<arith::XOrIOp>(
- absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value absRealIsNotInf = arith::XOrIOp::create(
+ b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
- Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
- imagNum, imagNum, fmf);
+ Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
+ imagNum, imagNum, fmf);
Value resultRealIsNaN =
- b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
- Value resultImagIsZero = b.create<arith::OrIOp>(
- imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
+ arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
+ Value resultImagIsZero = arith::OrIOp::create(
+ b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
- resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
+ resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
resultImag =
- b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
+ arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
}
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(resultReal, resultImag);
- resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
+ resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
}
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
@@ -799,10 +808,10 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
@@ -818,97 +827,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
arith::FastMathFlags fmf) {
auto elementType = cast<FloatType>(type.getElementType());
- Value a = builder.create<complex::ReOp>(lhs);
- Value b = builder.create<complex::ImOp>(lhs);
+ Value a = complex::ReOp::create(builder, lhs);
+ Value b = complex::ImOp::create(builder, lhs);
- Value abs = builder.create<complex::AbsOp>(lhs, fmf);
- Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
+ Value abs = complex::AbsOp::create(builder, lhs, fmf);
+ Value absToC = math::PowFOp::create(builder, abs, c, fmf);
- Value negD = builder.create<arith::NegFOp>(d, fmf);
- Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
- Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
- Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
+ Value negD = arith::NegFOp::create(builder, d, fmf);
+ Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
+ Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
+ Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
- Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
- Value lnAbs = builder.create<math::LogOp>(abs, fmf);
- Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
- Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
- Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
- Value cosQ = builder.create<math::CosOp>(q, fmf);
- Value sinQ = builder.create<math::SinOp>(q, fmf);
+ Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
+ Value lnAbs = math::LogOp::create(builder, abs, fmf);
+ Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
+ Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
+ Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
+ Value cosQ = math::CosOp::create(builder, q, fmf);
+ Value sinQ = math::SinOp::create(builder, q, fmf);
- Value inf = builder.create<arith::ConstantOp>(
- elementType,
+ Value inf = arith::ConstantOp::create(
+ builder, elementType,
builder.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
- Value zero = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0.0));
- Value one = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 1.0));
- Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
- Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
- Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
+ Value zero = arith::ConstantOp::create(
+ builder, elementType, builder.getFloatAttr(elementType, 0.0));
+ Value one = arith::ConstantOp::create(builder, elementType,
+ builder.getFloatAttr(elementType, 1.0));
+ Value complexOne = complex::CreateOp::create(builder, type, one, zero);
+ Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
+ Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
// Case 0:
// d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
Value absEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
Value dEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
Value cEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
Value bEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
Value zeroLeC =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
- Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
- Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
+ Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
+ Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
Value complexOneOrZero =
- builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
+ arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
Value coeffCosSin =
- builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
- Value cutoff0 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(
- builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
+ complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
+ Value cutoff0 = arith::SelectOp::create(
+ builder,
+ arith::AndIOp::create(
+ builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
complexOneOrZero, coeffCosSin);
// Case 1:
// x^0 is defined to be 1 for any x, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
- Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+ Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
Value cutoff1 =
- builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
+ arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
// Case 2:
// 1^(c + d*i) = 1 + 0*i
- Value lhsEqOne = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
+ Value lhsEqOne = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
bEqZero);
Value cutoff2 =
- builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
+ arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
// Case 3:
// inf^(c + 0*i) = inf + 0*i, c > 0
- Value lhsEqInf = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
+ Value lhsEqInf = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
bEqZero);
- Value rhsGt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
- Value cutoff3 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
+ Value rhsGt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
+ Value cutoff3 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
+ cutoff2);
// Case 4:
// inf^(c + 0*i) = 0 + 0*i, c < 0
- Value rhsLt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
- Value cutoff4 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
+ Value rhsLt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
+ Value cutoff4 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
+ cutoff3);
return cutoff4;
}
@@ -923,8 +937,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
auto type = cast<ComplexType>(adaptor.getLhs().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
+ Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
c, d, op.getFastmath())});
@@ -945,64 +959,64 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
Value inf = cst(APFloat::getInf(floatSemantics));
- Value negHalf = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -0.5));
+ Value negHalf = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -0.5));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
- Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
- Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
+ Value cos = math::CosOp::create(b, rsqrtArg, fmf);
+ Value sin = math::SinOp::create(b, rsqrtArg, fmf);
- Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
- Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
+ Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
+ Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1));
- Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
- Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
+ Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
+ Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
Value negImagSignedZero =
- b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
+ arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
Value realIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
- Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
- Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
+ Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
+ Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
- Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
+ Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
resultReal =
- b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
- resultImag);
+ arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
+ resultImag);
}
Value isRealZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
Value isImagZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
- resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
- resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
+ resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
+ resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -1021,9 +1035,9 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 13a084407e53f..ff6d369176393 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
- abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
- "abort", abortFuncTy);
+ abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
+ "abort", abortFuncTy);
}
- rewriter.create<LLVM::CallOp>(loc, abortFunc, ValueRange());
- rewriter.create<LLVM::UnreachableOp>(loc);
+ LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
+ LLVM::UnreachableOp::create(rewriter, loc);
} else {
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
}
// Generate assertion test.
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index 9831dcaaaccc8..c8311eb5a6433 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
MutableArrayRef<Region> regions) {
if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
assert(regions.size() == 2);
- auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
- resultTypes, condBrOp.getCondition());
+ auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
+ resultTypes, condBrOp.getCondition());
ifOp.getThenRegion().takeBody(regions[0]);
ifOp.getElseRegion().takeBody(regions[1]);
return ifOp.getOperation();
@@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
// `getCFGSwitchValue` returns an i32 that we need to convert to index
// fist.
- auto cast = builder.create<arith::IndexCastUIOp>(
- controlFlowCondOp->getLoc(), builder.getIndexType(),
+ auto cast = arith::IndexCastUIOp::create(
+ builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
switchOp.getFlag());
SmallVector<int64_t> cases;
if (auto caseValues = switchOp.getCaseValues())
@@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
assert(regions.size() == cases.size() + 1);
- auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
- controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
+ auto indexSwitchOp =
+ scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
+ resultTypes, cast, cases, cases.size());
indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
for (auto &&[targetRegion, sourceRegion] :
@@ -75,7 +76,7 @@ LogicalResult
ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
Location loc, OpBuilder &builder, Operation *branchRegionOp,
Operation *replacedControlFlowOp, ValueRange results) {
- builder.create<scf::YieldOp>(loc, results);
+ scf::YieldOp::create(builder, loc, results);
return success();
}
@@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
Location loc = replacedOp->getLoc();
- auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
- loopVariablesInit);
+ auto whileOp = scf::WhileOp::create(
+ builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
whileOp.getBefore().takeBody(loopBody);
builder.setInsertionPointToEnd(&whileOp.getBefore().back());
// `getCFGSwitchValue` returns a i32. We therefore need to truncate the
// condition to i1 first. It is guaranteed to be either 0 or 1 already.
- builder.create<scf::ConditionOp>(
- loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
+ scf::ConditionOp::create(
+ builder, loc,
+ arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
loopVariablesNextIter);
Block *afterBlock = builder.createBlock(&whileOp.getAfter());
afterBlock->addArguments(
loopVariablesInit.getTypes(),
SmallVector<Location>(loopVariablesInit.size(), loc));
- builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
+ scf::YieldOp::create(builder, loc, afterBlock->getArguments());
return whileOp.getOperation();
}
@@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
OpBuilder &builder,
unsigned int value) {
- return builder.create<arith::ConstantOp>(loc,
- builder.getI32IntegerAttr(value));
+ return arith::ConstantOp::create(builder, loc,
+ builder.getI32IntegerAttr(value));
}
void ControlFlowToSCFTransformation::createCFGSwitchOp(
@@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp(
ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseArguments, Block *defaultDest,
ValueRange defaultArgs) {
- builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
- llvm::to_vector_of<int32_t>(caseValues),
- caseDestinations, caseArguments);
+ cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
+ llvm::to_vector_of<int32_t>(caseValues),
+ caseDestinations, caseArguments);
}
Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
OpBuilder &builder,
Type type) {
- return builder.create<ub::PoisonOp>(loc, type, nullptr);
+ return ub::PoisonOp::create(builder, loc, type, nullptr);
}
FailureOr<Operation *>
diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
index f8dc06f41ab87..197caeb4ffbfa 100644
--- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
+++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
@@ -99,8 +99,8 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
}
// Create the converted `emitc.func` op.
- emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ emitc::FuncOp newFuncOp = emitc::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
FunctionType::get(rewriter.getContext(),
signatureConverter.getConvertedTypes(),
resultType ? TypeRange(resultType) : TypeRange()));
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 36235636d6ba2..67bb1c14c99a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
SmallVector<NamedAttribute> attributes;
filterFuncAttributes(funcOp, attributes);
- auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
@@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(index + argOffset);
if (auto memrefType = dyn_cast<MemRefType>(argType)) {
- Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(memrefType), arg);
+ Value loaded = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter.convertType(memrefType), arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
continue;
}
if (isa<UnrankedMemRefType>(argType)) {
- Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(argType), arg);
+ Value loaded = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter.convertType(argType), arg);
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
continue;
}
@@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
args.push_back(arg);
}
- auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
+ auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
if (resultStructType) {
- rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
- wrapperFuncOp.getArgument(0));
- rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
+ LLVM::StoreOp::create(rewriter, loc, call.getResult(),
+ wrapperFuncOp.getArgument(0));
+ LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
} else {
- rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+ LLVM::ReturnOp::create(rewriter, loc, call.getResults());
}
}
@@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
filterFuncAttributes(funcOp, attributes);
// Create the auxiliary function.
- auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
- loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ auto wrapperFunc = LLVM::LLVMFuncOp::create(
+ builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
@@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
if (resultStructType) {
// Allocate the struct on the stack and pass the pointer.
Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
- Value one = builder.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(builder.getIndexType()),
+ Value one = LLVM::ConstantOp::create(
+ builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value result =
- builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
+ LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
args.push_back(result);
}
@@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
wrapperArgsRange.take_front(numToDrop));
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
- Value one = builder.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(builder.getIndexType()),
+ Value one = LLVM::ConstantOp::create(
+ builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
- Value allocated = builder.create<LLVM::AllocaOp>(
- loc, ptrTy, packed.getType(), one, /*alignment=*/0);
- builder.create<LLVM::StoreOp>(loc, packed, allocated);
+ Value allocated = LLVM::AllocaOp::create(
+ builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
+ LLVM::StoreOp::create(builder, loc, packed, allocated);
arg = allocated;
} else {
arg = wrapperArgsRange[0];
@@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
- auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
+ auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
if (resultStructType) {
Value result =
- builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
- builder.create<LLVM::ReturnOp>(loc, result);
+ LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
+ LLVM::ReturnOp::create(builder, loc, result);
} else {
- builder.create<LLVM::ReturnOp>(loc, call.getResults());
+ LLVM::ReturnOp::create(builder, loc, call.getResults());
}
}
@@ -283,7 +283,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
@@ -357,8 +357,8 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
symbolTable.remove(funcOp);
}
- auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
+ auto newFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);
@@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto newOp =
- rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
+ LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value")
continue;
@@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter, useBarePtrCallConv);
- auto newOp = rewriter.create<LLVM::CallOp>(
- callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
- promoted, callOp->getAttrs());
+ auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
+ packedResult ? TypeRange(packedResult)
+ : TypeRange(),
+ promoted, callOp->getAttrs());
newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(promoted.size()), 0};
@@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- callOp.getLoc(), newOp->getResult(0), i));
+ results.push_back(LLVM::ExtractValueOp::create(
+ rewriter, callOp.getLoc(), newOp->getResult(0), i));
}
}
@@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
- Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
+ Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 01ca5e99a9aff..1037e296c8128 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
- ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+ ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
@@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
+ return LLVM::GlobalOp::create(b, loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
}
LogicalResult
@@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
gpuFuncOp.getWorkgroupAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- auto globalOp = rewriter.create<LLVM::GlobalOp>(
- gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+ auto globalOp = LLVM::GlobalOp::create(
+ rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
workgroupAddrSpace);
workgroupBuffers.push_back(globalOp);
@@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
? kernelCallingConvention
: nonKernelCallingConvention;
- auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
+ auto llvmFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
/*comdat=*/nullptr, attributes);
@@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
global.getAddrSpace());
- Value address = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, global.getSymNameAttr());
+ Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
+ global.getSymNameAttr());
Value memory =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
- address, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
+ address, ArrayRef<LLVM::GEPArg>{0, 0});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
Type elementType = typeConverter->convertType(type.getElementType());
auto ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
+ Value numElements = LLVM::ConstantOp::create(
+ rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- Value allocated = rewriter.create<LLVM::AllocaOp>(
- gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
+ Value allocated =
+ LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
+ elementType, numElements, alignment);
Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
@@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
{llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
- Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
- auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
+ Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
+ auto printfBeginCall =
+ LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
@@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
- Value stringLen = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringLen = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
- Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
- Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
+ Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
+ Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
- auto appendFormatCall = rewriter.create<LLVM::CallOp>(
- loc, ocklAppendStringN,
+ auto appendFormatCall = LLVM::CallOp::create(
+ rewriter, loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.getArgs().empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
@@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
- rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
+ LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.getArgs()[i];
if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
if (!floatType.isF64())
- arg = rewriter.create<LLVM::FPExtOp>(
- loc, typeConverter->convertType(rewriter.getF64Type()), arg);
- arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
+ arg = LLVM::FPExtOp::create(
+ rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
+ arg);
+ arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
- arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
+ arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
arguments.push_back(arg);
}
@@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
- auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
+ auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
@@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
/*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
"printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
@@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
assert(type.isIntOrFloat());
if (isa<FloatType>(type)) {
type = rewriter.getF64Type();
- promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
+ promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
}
types.push_back(type);
args.push_back(promotedArg);
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- rewriter.getIndexAttr(1));
+ Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(1));
Value tempAlloc =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
- /*alignment=*/0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
+ /*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
- Value ptr = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, structType, tempAlloc,
+ Value ptr = LLVM::GEPOp::create(
+ rewriter, loc, ptrType, structType, tempAlloc,
ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
- rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
+ LLVM::StoreOp::create(rewriter, loc, arg, ptr);
}
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
- rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
TypeRange operandTypes(operands);
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
- Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
- Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+ Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
auto extractElement = [&](Value operand) -> Value {
if (!isa<VectorType>(operand.getType()))
return operand;
- return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+ return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
};
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
Operation *scalarOp =
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
- result = rewriter.create<LLVM::InsertElementOp>(
- loc, result, scalarOp->getResult(0), index);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ scalarOp->getResult(0), index);
}
return result;
}
@@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
- return rewriter.create<LLVM::GlobalOp>(
- op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace.value());
+ return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
+ /*isConstant=*/false, LLVM::Linkage::Internal,
+ symName, /*value=*/Attribute(), alignmentByte,
+ addressSpace.value());
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
@@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
- auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
Type baseType = basePtr->getResultTypes().front();
// Step 4. Generate GEP using offsets
SmallVector<LLVM::GEPArg> gepArgs = {0};
- Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
- basePtr, gepArgs);
+ Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
+ basePtr, gepArgs);
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
@@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
- Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
+ Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 167cabbc57db9..63eb6c58e87a7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -79,8 +79,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
uint64_t rank = type.getRank();
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
for (unsigned i = 1; i < rank; i++)
- numElements = rewriter.create<LLVM::MulOp>(
- loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements,
+ desc.size(rewriter, loc, /*pos=*/i));
return numElements;
}
@@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
- return builder.create<LLVM::CallOp>(loc, function, arguments);
+ return LLVM::CallOp::create(builder, loc, function, arguments);
}
// Corresponding to cusparseIndexType_t defined in cusparse.h.
@@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
- auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
+ auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
Value stream = adaptor.getAsyncDependencies().empty()
? nullPtr
: adaptor.getAsyncDependencies().front();
- auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
+ auto isHostShared = mlir::LLVM::ConstantOp::create(
+ rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
@@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
static_cast<uint64_t>(memrefTy.getNumElements());
- Value sizeArg = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+ Value sizeArg = LLVM::ConstantOp::create(
+ rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
llvmArgumentsWithSizes.push_back(sizeArg);
}
@@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
adaptor.getClusterSizeZ()};
}
- rewriter.create<gpu::LaunchFuncOp>(
- launchOp.getLoc(), launchOp.getKernelAttr(),
+ gpu::LaunchFuncOp::create(
+ rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
@@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc,
const LLVMTypeConverter &typeConverter) {
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
- sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc,
+ sourcePtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(),
destinationType.getAddressSpace()),
sourcePtr);
@@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Type elementPtrType = getElementPtrType(memRefType);
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType,
+ Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
+ Value gepPtr = LLVM::GEPOp::create(
+ rewriter, loc, elementPtrType,
typeConverter->convertType(memRefType.getElementType()), nullPtr,
numElements);
auto sizeBytes =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
srcDesc.alignedPtr(rewriter, loc),
@@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
auto value =
- rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
+ LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
@@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmInt32Type = builder.getIntegerType(32);
- return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- static_cast<int32_t>(tValue));
+ return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
+ static_cast<int32_t>(tValue));
}
template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmFloat32Type = builder.getF32Type();
- return builder.create<LLVM::ConstantOp>(
- loc, llvmFloat32Type,
+ return LLVM::ConstantOp::create(
+ builder, loc, llvmFloat32Type,
builder.getF32FloatAttr(static_cast<float>(tValue)));
}
@@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
// the dnmat is used with spmat with 2:4 sparsity
if (dims.size() == 2) {
if (isSpMMCusparseLtOp(op.getDnTensor())) {
- auto handleSz = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(11032));
- handle = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
- handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
+ auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(11032));
+ handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
+ llvmInt8Type, handleSz, /*alignment=*/16);
+ handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
createLtDnMatCallBuilder
.create(loc, rewriter,
@@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
// CUDA runner asserts the size is 44104 bytes.
- auto handleSz = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(44104));
- Value handle = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
- handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
+ auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(44104));
+ Value handle = LLVM::AllocaOp::create(
+ rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
+ handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
create2To4SpMatCallBuilder
.create(loc, rewriter,
@@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
- auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(3));
- auto bufferSize = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
+ auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(3));
+ auto bufferSize =
+ LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
+ three, /*alignment=*/16);
createCuSparseLtSpMMBufferSizeBuilder
.create(loc, rewriter,
{bufferSize, modeA, modeB, adaptor.getSpmatA(),
@@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
pruneFlag, stream})
.getResult();
- auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, bufferSize,
- ValueRange{rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(1))});
- auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, bufferSize,
- ValueRange{rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(2))});
+ auto bufferSizePtr1 = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1))});
+ auto bufferSizePtr2 = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(2))});
auto bufferSize0 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
auto bufferSize1 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
auto bufferSize2 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
} else {
@@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
- auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(3));
- auto buffer = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
-
- auto rowsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(0))});
- auto colsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(1))});
- auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(2))});
+ auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(3));
+ auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
+ llvmInt64Type, three, /*alignment=*/16);
+
+ auto rowsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))});
+ auto colsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1))});
+ auto nnzsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(2))});
createSpMatGetSizeBuilder.create(
loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
- auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
- auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
- auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
+ auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
+ auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
+ auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
rewriter.replaceOp(op, {rows, cols, nnzs, stream});
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index aab2409ed6328..91c43e8bd1117 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -59,13 +59,13 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
Operation *newOp;
switch (op.getDimension()) {
case gpu::Dimension::x:
- newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
+ newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::y:
- newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
+ newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::z:
- newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
+ newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
}
@@ -124,11 +124,13 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
rewriter.getContext(), 32, min, max));
}
if (indexBitwidth > 32) {
- newOp = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
+ newOp = LLVM::SExtOp::create(rewriter, loc,
+ IntegerType::get(context, indexBitwidth),
+ newOp->getResult(0));
} else if (indexBitwidth < 32) {
- newOp = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
+ newOp = LLVM::TruncOp::create(rewriter, loc,
+ IntegerType::get(context, indexBitwidth),
+ newOp->getResult(0));
}
rewriter.replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 64cf09e600b88..9f36e5c369d06 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -103,7 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
- rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
+ LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -115,19 +115,20 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// there is no guarantee of a specific value being used to indicate true,
// compare for inequality with zero (rather than truncate or shift).
if (isResultBool) {
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getIntegerType(32),
- rewriter.getI32IntegerAttr(0));
- Value truncated = rewriter.create<LLVM::ICmpOp>(
- op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
+ Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getIntegerType(32),
+ rewriter.getI32IntegerAttr(0));
+ Value truncated =
+ LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
+ callOp.getResult(), zero);
rewriter.replaceOp(op, {truncated});
return success();
}
assert(callOp.getResult().getType().isF32() &&
"only f32 types are supposed to be truncated back");
- Value truncated = rewriter.create<LLVM::FPTruncOp>(
- op->getLoc(), adaptor.getOperands().front().getType(),
+ Value truncated = LLVM::FPTruncOp::create(
+ rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult());
rewriter.replaceOp(op, {truncated});
return success();
@@ -142,8 +143,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
if (!f16Func.empty() && isa<Float16Type>(type))
return operand;
- return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
}
Type getFunctionType(Type resultType, ValueRange operands) const {
@@ -169,7 +171,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
- return b.create<LLVMFuncOp>(globalloc, funcName, funcType);
+ return LLVMFuncOp::create(b, globalloc, funcName, funcType);
}
StringRef getFunctionName(Type type, SourceOp op) const {
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 8b6b553f6eed0..c2363a1a40294 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
SymbolTable::lookupSymbolIn(symbolTable, name));
if (!func) {
OpBuilder b(symbolTable->getRegion(0));
- func = b.create<LLVM::LLVMFuncOp>(
- symbolTable->getLoc(), name,
+ func = LLVM::LLVMFuncOp::create(
+ b, symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setNoUnwind(true);
@@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func,
ValueRange args) {
- auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
@@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
constexpr int64_t localMemFenceFlag = 1;
Location loc = op->getLoc();
Value flag =
- rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
+ LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
return success();
}
@@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern {
Location loc = op->getLoc();
gpu::Dimension dim = getDimension(op);
- Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
- static_cast<int64_t>(dim));
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
return success();
}
@@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(oldVal.getType())
.Case([&](BFloat16Type) {
- return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
- oldVal);
+ return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
+ oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
- return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
- oldVal);
+ return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
+ oldVal);
return oldVal;
})
.Default(oldVal);
@@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(newTy)
.Case([&](BFloat16Type) {
- return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
- return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
return oldVal;
})
.Default(oldVal);
@@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
Value trueVal =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {resultOrConversion, trueVal});
return success();
}
@@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
return failure();
}
- result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
+ result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 1ef6edea93c58..317bfc2970cf5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering
Location loc = op->getLoc();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
- Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
+ Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
- auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
- mode.value(), offset);
+ auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
+ op.getValue(), mode.value(), offset);
rewriter.replaceOp(op, reduxOp->getResult(0));
return success();
@@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto predTy = IntegerType::get(rewriter.getContext(), 1);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
- Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
- Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
- Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
- loc, int32Type, thirtyTwo, adaptor.getWidth());
+ Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
+ Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
+ Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
+ Value numLeadInactiveLane = LLVM::SubOp::create(
+ rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
- Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
- numLeadInactiveLane);
+ Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
+ numLeadInactiveLane);
Value maskAndClamp;
if (op.getMode() == gpu::ShuffleMode::UP) {
// Clamp lane: `32 - activeWidth`
maskAndClamp = numLeadInactiveLane;
} else {
// Clamp lane: `activeWidth - 1`
- maskAndClamp =
- rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
+ maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
+ adaptor.getWidth(), one);
}
bool predIsUsed = !op->getResult(1).use_empty();
@@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueTy, predTy});
}
- Value shfl = rewriter.create<NVVM::ShflOp>(
- loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
- maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
+ Value shfl = NVVM::ShflOp::create(
+ rewriter, loc, resultTy, activeMask, adaptor.getValue(),
+ adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
+ returnValueAndIsValidAttr);
if (predIsUsed) {
- Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+ Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
Value isActiveSrcLane =
- rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
} else {
rewriter.replaceOp(op, {shfl, nullptr});
@@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
Value newOp =
- rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
+ NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
- newOp = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::SExtOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
- newOp = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::TruncOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
return success();
@@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
- rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
- assertBlock);
+ cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
+ assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
- rewriter.create<cf::BranchOp>(loc, afterBlock);
+ cf::BranchOp::create(rewriter, loc, afterBlock);
// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);
@@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering
// Create constants.
auto getGlobal = [&](LLVM::GlobalOp global) {
// Get a pointer to the format string's first element.
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
global.getSymNameAttr());
Value start =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
return start;
};
Value assertMessage = getGlobal(getOrCreateStringConstant(
@@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering
Value assertFunc = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
Value assertLine =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
- Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
+ Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
// Insert function call to __assertfail.
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 45fd933d58857..99c059cb03299 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices());
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value leadingDim = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
@@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering
auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
Value toUse =
- rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
+ LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
storeOpOperands.push_back(toUse);
}
@@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering
rewriter, loc,
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
adaptor.getDstMemref(), adaptor.getIndices());
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value leadingDim = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
@@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering
auto unpackOp = [&](Value operand) {
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+ Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
unpackedOps.push_back(toUse);
}
};
@@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
- Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
+ Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
- Value idx = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), vecEl);
- vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
- cst, idx);
+ Value idx = LLVM::ConstantOp::create(rewriter, loc,
+ rewriter.getI32Type(), vecEl);
+ vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
+ cst, idx);
}
cst = vecCst;
}
- Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
matrixStruct =
- rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
+ LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
return success();
@@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
i1Type = VectorType::get(vecType.getShape(), i1Type);
- Value cmp = builder.create<LLVM::FCmpOp>(
- loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
- lhs, rhs);
- Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
- Value isNan = builder.create<LLVM::FCmpOp>(
- loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
- Value nan = builder.create<LLVM::ConstantOp>(
- loc, lhs.getType(),
+ Value cmp = LLVM::FCmpOp::create(
+ builder, loc, i1Type,
+ isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
+ Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
+ Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
+ LLVM::FCmpPredicate::uno, lhs, rhs);
+ Value nan = LLVM::ConstantOp::create(
+ builder, loc, lhs.getType(),
builder.getFloatAttr(floatType,
APFloat::getQNaN(floatType.getFloatSemantics())));
- return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+ return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
}
static Value createScalarOp(OpBuilder &builder, Location loc,
@@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
ArrayRef<Value> operands) {
switch (op) {
case gpu::MMAElementwiseOp::ADDF:
- return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
+ return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MULF:
- return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+ return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::DIVF:
- return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
+ return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MAXF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/false);
@@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
- Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
- extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getOperands()[opIdx], i));
+ extractedOperands.push_back(LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getOperands()[opIdx], i));
}
Value element =
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
extractedOperands);
matrixStruct =
- rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
+ LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
}
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
return success();
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 456bfaba980ca..d22364e1ef441 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
// TODO: use <=> in C++20.
if (indexBitwidth > intWidth) {
- return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+ return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
}
if (indexBitwidth < intWidth) {
- return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+ return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
}
return value;
}
@@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
const unsigned indexBitwidth) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
- Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
- Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
- Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
- ValueRange{minus1, zero});
- Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
- ValueRange{minus1, mbcntLo});
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
+ ValueRange{minus1, zero});
+ Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
+ ValueRange{minus1, mbcntLo});
return laneId;
}
static constexpr StringLiteral amdgcnDataLayout =
@@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
// followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
Type intTy = IntegerType::get(context, 32);
- Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
- Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
- Value mbcntLo =
- rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
- Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
- loc, intTy, ValueRange{minus1, mbcntLo});
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
+ ValueRange{minus1, zero});
+ Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
+ ValueRange{minus1, mbcntLo});
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
- laneId = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), laneId);
+ laneId = LLVM::SExtOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
} else if (indexBitwidth < 32) {
- laneId = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), laneId);
+ laneId = LLVM::TruncOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
}
rewriter.replaceOp(op, {laneId});
return success();
@@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
}
- Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
- op.getLoc(), rewriter.getI32Type(), bounds);
+ Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
*getTypeConverter());
rewriter.replaceOp(op, {wavefrontOp});
@@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
- Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
- Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
+ Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
+ Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
Value widthOrZeroIfOutside =
- rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
+ LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
Value dstLane;
switch (op.getMode()) {
case gpu::ShuffleMode::UP:
- dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::DOWN:
- dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::XOR:
- dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::IDX:
dstLane = adaptor.getOffset();
break;
}
- Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
- Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
- dstLane, srcLaneId);
- Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
+ Value isActiveSrcLane = LLVM::ICmpOp::create(
+ rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
+ Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
+ dstLane, srcLaneId);
+ Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
Value dwordAlignedDstLane =
- rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
+ LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
SmallVector<Value> swizzled;
for (Value v : decomposed) {
- Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
- dwordAlignedDstLane, v);
+ Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
+ dwordAlignedDstLane, v);
swizzled.emplace_back(res);
}
Value shflValue =
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b99ed261ecfa3..a19194eb181fb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value vector =
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
- Value dim = rewriter.create<spirv::CompositeExtractOp>(
- op.getLoc(), builtinType, vector,
+ Value dim = spirv::CompositeExtractOp::create(
+ rewriter, op.getLoc(), builtinType, vector,
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
if (forShader && builtinType != indexType)
- dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
+ dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
rewriter.replaceOp(op, dim);
return success();
}
@@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value builtinValue =
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
if (i32Type != indexType)
- builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
- builtinValue);
+ builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
+ builtinValue);
rewriter.replaceOp(op, builtinValue);
return success();
}
@@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
- auto newFuncOp = rewriter.create<spirv::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ auto newFuncOp = spirv::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
@@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
- auto spvModule = rewriter.create<spirv::ModuleOp>(
- moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
+ auto spvModule = spirv::ModuleOp::create(
+ rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
@@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR: {
- result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleXorOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::IDX: {
- result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
- result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleDownOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- resultLaneId, adaptor.getWidth());
+ arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ resultLaneId, adaptor.getWidth());
break;
}
case gpu::ShuffleMode::UP: {
- result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleUpOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
+ arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
- validVal = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, resultLaneId,
- rewriter.create<arith::ConstantOp>(
- loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
+ validVal = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
+ arith::ConstantOp::create(rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, 0)));
break;
}
}
@@ -516,15 +516,16 @@ LogicalResult GPURotateConversion::matchAndRewrite(
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
- Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+ Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
+ adaptor.getWidth());
Value validVal;
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ laneId, adaptor.getWidth());
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -548,14 +549,14 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
? spirv::GroupOperation::ClusteredReduce
: spirv::GroupOperation::Reduce);
if (isUniform) {
- return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
+ return UniformOp::create(builder, loc, type, scope, groupOp, arg)
.getResult();
}
Value clusterSizeValue;
if (clusterSize.has_value())
- clusterSizeValue = builder.create<spirv::ConstantOp>(
- loc, builder.getI32Type(),
+ clusterSizeValue = spirv::ConstantOp::create(
+ builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
return builder
@@ -740,8 +741,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstName =
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
- return rewriter.create<spirv::SpecConstantOp>(
- loc, rewriter.getStringAttr(specCstName), attr);
+ return spirv::SpecConstantOp::create(
+ rewriter, loc, rewriter.getStringAttr(specCstName), attr);
};
{
Operation *parent =
@@ -774,8 +775,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstCompositeName =
(llvm::Twine(globalVarName) + "_scc").str();
- specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
- loc, TypeAttr::get(globalType),
+ specCstComposite = spirv::SpecConstantCompositeOp::create(
+ rewriter, loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));
@@ -785,23 +786,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
// Define a GlobalVarOp initialized using specialized constants
// that is used to specify the printf format string
// to be passed to the SPIRV CLPrintfOp.
- globalVar = rewriter.create<spirv::GlobalVariableOp>(
- loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
+ globalVar = spirv::GlobalVariableOp::create(
+ rewriter, loc, ptrType, globalVarName,
+ FlatSymbolRefAttr::get(specCstComposite));
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
// Get SSA value of Global variable and create pointer to i8 to point to
// the format string.
- Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
- Value fmtStr = rewriter.create<spirv::BitcastOp>(
- loc,
+ Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
+ Value fmtStr = spirv::BitcastOp::create(
+ rewriter, loc,
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);
// Get printf arguments.
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
- rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+ spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
// Need to erase the gpu.printf op as gpu.printf does not use result vs
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0b2c06a08db2d..a344f88326089 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() {
if (targetEnvSupportsKernelCapability(moduleOp)) {
moduleOp.walk([&](gpu::GPUFuncOp funcOp) {
builder.setInsertionPoint(funcOp);
- auto newFuncOp = builder.create<func::FuncOp>(
- funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
+ auto newFuncOp =
+ func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(),
+ funcOp.getFunctionType());
auto entryBlock = newFuncOp.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
- builder.create<func::ReturnOp>(funcOp.getLoc());
+ func::ReturnOp::create(builder, funcOp.getLoc());
newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
funcOp.erase();
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 7bb86b5ce1ddd..51dc50048024f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto strideValue = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
@@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto strideValue = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
index 0473bb59fa6aa..99d2f6ca78c38 100644
--- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
- Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
+ Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);
// Compute `x`.
Value mPos =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
- Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero);
+ Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// Compute the positive result.
- Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
- Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
- Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
+ Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x);
+ Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m);
+ Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne);
// Compute the negative result.
- Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
- Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
- Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
+ Value negN = LLVM::SubOp::create(rewriter, loc, zero, n);
+ Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m);
+ Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM);
// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
- Value sameSign =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero);
+ Value sameSign = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, nPos, mPos);
Value nNonZero =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
- Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
+ Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
return success();
}
@@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
// Compute the non-zero result.
- Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
- Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
- Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
+ Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one);
+ Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m);
+ Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one);
// Pick the result.
Value cmp =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
return success();
}
@@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
- Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
+ Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);
// Compute `x`.
Value mNeg =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
- Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero);
+ Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
// Compute the negative result.
- Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
- Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
- Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
+ Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n);
+ Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m);
+ Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM);
// Compute the positive result.
- Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
+ Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m);
// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
- Value diffSign =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero);
+ Value diffSign = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::ne, nNeg, mNeg);
Value nNonZero =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
- Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
+ Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
return success();
}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index 4821962f989e6..36cfe9dd6e2db 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value posOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 1));
- Value negOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, -1));
// Compute `x`.
- Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
- Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
+ Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
+ Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// Compute the positive result.
- Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
- Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
- Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
+ Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
+ Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
+ Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
// Compute the negative result.
- Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
- Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
- Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
+ Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
+ Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
+ Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
- Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
- Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
- Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
- Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
+ Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
+ Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
+ Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
+ Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
@@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
- IntegerAttr::get(n_type, 1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
// Compute the non-zero result.
- Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
- Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
- Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
+ Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
+ Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
+ Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
// Pick the result
- Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
+ Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
return success();
}
@@ -197,32 +197,33 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value posOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 1));
- Value negOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, -1));
// Compute `x`.
- Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
- Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+ Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
+ Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
// Compute the negative result
- Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
- Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
- Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+ Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
+ Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
+ Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
// Compute the positive result.
- Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+ Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
- Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
- Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
- Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+ Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
+ Value diffSign =
+ spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
+ Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
- Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+ Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index e34d5f74d232f..fce7a3f324b86 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -32,7 +32,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
Type descriptorType) {
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
+ Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
return MemRefDescriptor(descriptor);
}
@@ -99,21 +99,21 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
// integer attribute.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
- return builder.create<LLVM::ConstantOp>(loc, resultType,
- builder.getIndexAttr(value));
+ return LLVM::ConstantOp::create(builder, loc, resultType,
+ builder.getIndexAttr(value));
}
/// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
- return builder.create<LLVM::ExtractValueOp>(loc, value,
- kOffsetPosInMemRefDescriptor);
+ return LLVM::ExtractValueOp::create(builder, loc, value,
+ kOffsetPosInMemRefDescriptor);
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
- value = builder.create<LLVM::InsertValueOp>(loc, value, offset,
- kOffsetPosInMemRefDescriptor);
+ value = LLVM::InsertValueOp::create(builder, loc, value, offset,
+ kOffsetPosInMemRefDescriptor);
}
/// Builds IR inserting the offset into the descriptor.
@@ -125,8 +125,9 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th size from the descriptor.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ return LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
}
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
@@ -137,23 +138,25 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
// Copy size values to stack-allocated memory.
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
- auto sizes = builder.create<LLVM::ExtractValueOp>(
- loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
- auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
- /*alignment=*/0);
- builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
+ auto sizes = LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
+ auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one,
+ /*alignment=*/0);
+ LLVM::StoreOp::create(builder, loc, sizes, sizesPtr);
// Load an return size value of interest.
- auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr,
- ArrayRef<LLVM::GEPArg>{0, pos});
- return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr);
+ auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr,
+ ArrayRef<LLVM::GEPArg>{0, pos});
+ return LLVM::LoadOp::create(builder, loc, indexType, resultPtr);
}
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ value = LLVM::InsertValueOp::create(
+ builder, loc, value, size,
+ ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
}
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
@@ -164,15 +167,16 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+ return LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value stride) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, stride,
+ value = LLVM::InsertValueOp::create(
+ builder, loc, value, stride,
ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
}
@@ -207,8 +211,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
? offset(builder, loc)
: createIndexAttrConstant(builder, loc, indexType, offsetCst);
Type elementType = converter.convertType(type.getElementType());
- ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
- offsetVal);
+ ptr = LLVM::GEPOp::create(builder, loc, ptr.getType(), elementType, ptr,
+ offsetVal);
return ptr;
}
@@ -303,7 +307,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
Location loc,
Type descriptorType) {
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
+ Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
@@ -380,19 +384,19 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
- builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
// (1 + 2 * rank) * sizeof(index)
Value rank = desc.rank(builder, loc);
- Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
Value doubleRankIncremented =
- builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
- Value rankIndexSize = builder.create<LLVM::MulOp>(
- loc, indexType, doubleRankIncremented, indexSize);
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
// Total allocation size.
- Value allocationSize = builder.create<LLVM::AddOp>(
- loc, indexType, doublePointerSize, rankIndexSize);
+ Value allocationSize = LLVM::AddOp::create(
+ builder, loc, indexType, doublePointerSize, rankIndexSize);
sizes.push_back(allocationSize);
}
}
@@ -400,13 +404,13 @@ void UnrankedMemRefDescriptor::computeSizes(
Value UnrankedMemRefDescriptor::allocatedPtr(
OpBuilder &builder, Location loc, Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType) {
- return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr);
+ return LLVM::LoadOp::create(builder, loc, elemPtrType, memRefDescPtr);
}
void UnrankedMemRefDescriptor::setAllocatedPtr(
OpBuilder &builder, Location loc, Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) {
- builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr);
+ LLVM::StoreOp::create(builder, loc, allocatedPtr, memRefDescPtr);
}
static std::pair<Value, Type>
@@ -423,9 +427,9 @@ Value UnrankedMemRefDescriptor::alignedPtr(
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
Value alignedGep =
- builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
- return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep);
+ LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep);
}
void UnrankedMemRefDescriptor::setAlignedPtr(
@@ -435,9 +439,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr(
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
Value alignedGep =
- builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
- builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
+ LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
+ LLVM::StoreOp::create(builder, loc, alignedPtr, alignedGep);
}
Value UnrankedMemRefDescriptor::offsetBasePtr(
@@ -446,8 +450,8 @@ Value UnrankedMemRefDescriptor::offsetBasePtr(
auto [elementPtrPtr, elemPtrPtrType] =
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
- return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
+ return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
}
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
@@ -456,8 +460,8 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
LLVM::LLVMPointerType elemPtrType) {
Value offsetPtr =
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
- return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
- offsetPtr);
+ return LLVM::LoadOp::create(builder, loc, typeConverter.getIndexType(),
+ offsetPtr);
}
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
@@ -467,7 +471,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
Value offsetPtr =
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
- builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
+ LLVM::StoreOp::create(builder, loc, offset, offsetPtr);
}
Value UnrankedMemRefDescriptor::sizeBasePtr(
@@ -477,8 +481,8 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Type structTy = LLVM::LLVMStructType::getLiteral(
indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
auto resultType = LLVM::LLVMPointerType::get(builder.getContext());
- return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr,
- ArrayRef<LLVM::GEPArg>{0, 3});
+ return LLVM::GEPOp::create(builder, loc, resultType, structTy, memRefDescPtr,
+ ArrayRef<LLVM::GEPArg>{0, 3});
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
@@ -489,8 +493,8 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value sizeStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
- return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index);
+ return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep);
}
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
@@ -501,8 +505,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value sizeStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
- builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index);
+ LLVM::StoreOp::create(builder, loc, size, sizeStoreGep);
}
Value UnrankedMemRefDescriptor::strideBasePtr(
@@ -511,7 +515,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(
Type indexTy = typeConverter.getIndexType();
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
- return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank);
+ return LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, rank);
}
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
@@ -522,8 +526,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value strideStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
- return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index);
+ return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep);
}
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
@@ -534,6 +538,6 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value strideStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
- builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index);
+ LLVM::StoreOp::create(builder, loc, stride, strideStoreGep);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..ecd5b6367fba4 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -57,8 +57,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
Location loc,
Type resultType,
int64_t value) {
- return builder.create<LLVM::ConstantOp>(loc, resultType,
- builder.getIndexAttr(value));
+ return LLVM::ConstantOp::create(builder, loc, resultType,
+ builder.getIndexAttr(value));
}
Value ConvertToLLVMPattern::getStridedElementPtr(
@@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
runningStride = sizes[i];
else if (stride == ShapedType::kDynamic)
runningStride =
- rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
+ LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
else
runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
}
@@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Buffer size in bytes.
Type elementType = typeConverter->convertType(memRefType.getElementType());
auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType, elementType, nullPtr, runningStride);
- size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
+ Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
+ elementType, nullPtr, runningStride);
+ size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
} else {
size = runningStride;
}
@@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// which is a common pattern of getting the size of a type in bytes.
Type llvmType = typeConverter->convertType(type);
auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
- auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
+ auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
}
Value ConvertToLLVMPattern::getNumElements(
@@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements(
staticSize == ShapedType::kDynamic
? dynamicSizes[dynamicIndex++]
: createIndexAttrConstant(rewriter, loc, indexType, staticSize);
- numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
} else {
numElements =
staticSize == ShapedType::kDynamic
@@ -276,14 +276,14 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
- ? builder
- .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
- .getResult()
- : builder.create<LLVM::AllocaOp>(loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
- builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
- builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
@@ -349,8 +349,8 @@ LogicalResult LLVM::detail::oneToOneRewrite(
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), newOp->getResult(0), i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
+ newOp->getResult(0), i));
}
rewriter.replaceOp(op, results);
return success();
@@ -371,8 +371,8 @@ LogicalResult LLVM::detail::intrinsicRewrite(
if (numResults != 0)
resType = typeConverter.packOperationResults(op->getResultTypes());
- auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
- loc, resType, rewriter.getStringAttr(intrinsic), operands);
+ auto callIntrOp = LLVM::CallIntrinsicOp::create(
+ rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
// Propagate attributes.
callIntrOp->setAttrs(op->getAttrDictionary());
@@ -388,7 +388,7 @@ LogicalResult LLVM::detail::intrinsicRewrite(
results.reserve(numResults);
Value intrRes = callIntrOp.getResults();
for (unsigned i = 0; i < numResults; ++i)
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
rewriter.replaceOp(op, results);
return success();
@@ -406,7 +406,7 @@ static unsigned getBitWidth(Type type) {
static Value createI32Constant(OpBuilder &builder, Location loc,
int32_t value) {
Type i32 = builder.getI32Type();
- return builder.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(builder, loc, i32, value);
}
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
@@ -418,17 +418,17 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
unsigned srcBitWidth = getBitWidth(srcType);
unsigned dstBitWidth = getBitWidth(dstType);
if (srcBitWidth == dstBitWidth) {
- Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+ Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
return {cast};
}
if (dstBitWidth > srcBitWidth) {
auto smallerInt = builder.getIntegerType(srcBitWidth);
if (srcType != smallerInt)
- src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+ src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
auto largerInt = builder.getIntegerType(dstBitWidth);
- Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+ Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
return {res};
}
assert(srcBitWidth % dstBitWidth == 0 &&
@@ -436,12 +436,12 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
int64_t numElements = srcBitWidth / dstBitWidth;
auto vecType = VectorType::get(numElements, dstType);
- src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+ src = LLVM::BitcastOp::create(builder, loc, vecType, src);
SmallVector<Value> res;
for (auto i : llvm::seq(numElements)) {
Value idx = createI32Constant(builder, loc, i);
- Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+ Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
res.emplace_back(elem);
}
@@ -461,28 +461,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
if (dstBitWidth < srcBitWidth) {
auto largerInt = builder.getIntegerType(srcBitWidth);
if (res.getType() != largerInt)
- res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+ res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
auto smallerInt = builder.getIntegerType(dstBitWidth);
- res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+ res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
int64_t numElements = src.size();
auto srcType = VectorType::get(numElements, src.front().getType());
- Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+ Value res = LLVM::PoisonOp::create(builder, loc, srcType);
for (auto &&[i, elem] : llvm::enumerate(src)) {
Value idx = createI32Constant(builder, loc, i);
- res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+ res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
@@ -518,20 +518,20 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
Value stride =
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(builder, loc, i)
- : builder.create<LLVM::ConstantOp>(
- loc, indexType, builder.getIndexAttr(strides[i]));
- increment =
- builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
+ : LLVM::ConstantOp::create(builder, loc, indexType,
+ builder.getIndexAttr(strides[i]));
+ increment = LLVM::MulOp::create(builder, loc, increment, stride,
+ intOverflowFlags);
}
- index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
- intOverflowFlags)
+ index = index ? LLVM::AddOp::create(builder, loc, index, increment,
+ intOverflowFlags)
: increment;
}
Type elementPtrType = memRefDescriptor.getElementPtrType();
- return index ? builder.create<LLVM::GEPOp>(
- loc, elementPtrType,
- converter.convertType(type.getElementType()), base, index,
- noWrapFlags)
- : base;
+ return index
+ ? LLVM::GEPOp::create(builder, loc, elementPtrType,
+ converter.convertType(type.getElementType()),
+ base, index, noWrapFlags)
+ : base;
}
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 49c73fbc9dd79..d95aeba8a4488 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -66,23 +66,23 @@ LogicalResult mlir::LLVM::createPrintStrCall(
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ auto globalOp = LLVM::GlobalOp::create(
+ builder, loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr);
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
// Emit call to `printStr` in runtime library.
builder.restoreInsertionPoint(ip);
auto msgAddr =
- builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName());
+ LLVM::AddressOfOp::create(builder, loc, ptrTy, globalOp.getName());
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
- builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
+ LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, msgAddr, indices);
FailureOr<LLVM::LLVMFuncOp> printer =
LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName);
if (failed(printer))
return failure();
- builder.create<LLVM::CallOp>(loc, TypeRange(),
- SymbolRefAttr::get(printer.value()), gep);
+ LLVM::CallOp::create(builder, loc, TypeRange(),
+ SymbolRefAttr::get(printer.value()), gep);
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
index 1cd0bd85f9894..13ed4628c3c9e 100644
--- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
@@ -24,10 +24,10 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) const {
- return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
+ return LLVM::ExtractValueOp::create(builder, loc, value, pos);
}
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value ptr) {
- value = builder.create<LLVM::InsertValueOp>(loc, value, ptr, pos);
+ value = LLVM::InsertValueOp::create(builder, loc, value, ptr, pos);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 7312594c761f7..1a9bf569086da 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -91,7 +91,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
if (!packed)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
.getResult(0);
}
@@ -107,7 +107,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
if (!packed)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
.getResult(0);
}
@@ -224,12 +224,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
});
@@ -731,12 +731,12 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
- Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
- builder.getIndexAttr(1));
+ Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
+ builder.getIndexAttr(1));
Value allocated =
- builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
+ LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one);
// Store into the alloca'ed descriptor.
- builder.create<LLVM::StoreOp>(loc, operand, allocated);
+ LLVM::StoreOp::create(builder, loc, operand, allocated);
return allocated;
}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index bf3f31729c3da..e7dd0b506e12d 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -87,17 +87,17 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
- Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
+ Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy);
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, 4> extractedOperands;
for (const auto &operand : llvm::enumerate(operands)) {
- extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, operand.value(), position));
+ extractedOperands.push_back(LLVM::ExtractValueOp::create(
+ rewriter, loc, operand.value(), position));
}
Value newVal = createOperand(result1DVectorTy, extractedOperands);
- desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
+ desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position);
});
rewriter.replaceOp(op, desc);
return success();
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index c3f213147b7a7..3f4b4d6cbc8ab 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -78,8 +78,8 @@ getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
// Insert before module terminator.
rewriter.setInsertionPoint(module.getBody(),
std::prev(module.getBody()->end()));
- func::FuncOp funcOp = rewriter.create<func::FuncOp>(
- op->getLoc(), fnNameAttr.getValue(), libFnType);
+ func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(),
+ fnNameAttr.getValue(), libFnType);
// Insert a function attribute that will trigger the emission of the
// corresponding `_mlir_ciface_xxx` interface so that external libraries see
// a normalized ABI. This interface is added during std to llvm conversion.
@@ -100,8 +100,8 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
res.push_back(op);
continue;
}
- Value cast =
- b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
+ Value cast = memref::CastOp::create(
+ b, loc, makeStridedLayoutDynamic(memrefType), op);
res.push_back(cast);
}
return res;
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d4deff5b88070..5b68eb8188996 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -54,18 +54,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
Value memRef, Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(
- loc, rewriter.getI64Type(), memRef, 2);
+ LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
+ Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
+ rewriter.getI64Type(), memRef, 2);
Value resPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+ LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
Value size;
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
- size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
} else {
- size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+ size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
}
return {resPtr, size};
}
@@ -157,13 +157,13 @@ class MPICHImplTraits : public MPIImplTraits {
Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) override {
static constexpr int MPI_COMM_WORLD = 0x44000000;
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- MPI_COMM_WORLD);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ MPI_COMM_WORLD);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
+ return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
}
intptr_t getStatusIgnore() override { return 1; }
@@ -195,7 +195,8 @@ class MPICHImplTraits : public MPIImplTraits {
mtype = MPI_UINT8_T;
else
assert(false && "unsupported type");
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ mtype);
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -245,7 +246,7 @@ class MPICHImplTraits : public MPIImplTraits {
op = MPI_REPLACE;
break;
}
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
}
};
@@ -281,16 +282,16 @@ class OMPIImplTraits : public MPIImplTraits {
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
// get address of symbol
- auto comm = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, name));
- return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
+ auto comm = LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, name));
+ return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::IntToPtrOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
+ return LLVM::IntToPtrOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
}
intptr_t getStatusIgnore() override { return 0; }
@@ -330,9 +331,9 @@ class OMPIImplTraits : public MPIImplTraits {
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, mtype));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, mtype));
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -389,9 +390,9 @@ class OMPIImplTraits : public MPIImplTraits {
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, op));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, op));
}
};
@@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
- auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
Value llvmnull = nullPtrOp.getRes();
// grab a reference to the global module op:
@@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
// get communicator
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
auto outPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
auto funcType =
@@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Comm_split", funcType);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
- outPtr.getRes()});
+ auto callOp =
+ LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{comm, adaptor.getColor(),
+ adaptor.getKey(), outPtr.getRes()});
// load the communicator into a register
- Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
- res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
+ Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
+ res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
// replace with function call
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
- auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{comm, rankptr.getRes()});
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+ auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
+ ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
- rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+ LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
- comm});
+ auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{dataPtr, size, dataType,
+ adaptor.getDest(),
+ adaptor.getTag(), comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
- loc, i64, mpiTraits->getStatusIgnore());
+ Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
+ mpiTraits->getStatusIgnore());
statusIgnore =
- rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
// LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
// tag, comm)`
@@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
@@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
- sendPtr = rewriter.create<LLVM::ConstantOp>(
- loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
- sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+ sendPtr = LLVM::ConstantOp::create(
+ rewriter, loc, i64,
+ reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+ sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
}
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
@@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
if (op.getRetval())
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 7f4655e53609e..08a456691880c 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -121,19 +121,19 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
initValueAttr = FloatAttr::get(resultElementType, 0.0);
else
initValueAttr = IntegerAttr::get(resultElementType, 0);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(vecType, initValueAttr));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (Value input : op->getOperands())
operands.push_back(
- rewriter.create<vector::ExtractOp>(loc, input, positions));
+ vector::ExtractOp::create(rewriter, loc, input, positions));
Value scalarOp =
- rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ Op::create(rewriter, loc, vecType.getElementType(), operands);
result =
- rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+ vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, result);
return success();
@@ -195,7 +195,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
FunctionType funcType = FunctionType::get(
builder.getContext(), {elementType, elementType}, elementType);
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
@@ -208,12 +208,12 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
- Value zeroValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 0));
- Value oneValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 1));
- Value minusOneValue = builder.create<arith::ConstantOp>(
- elementType,
+ Value zeroValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 0));
+ Value oneValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 1));
+ Value minusOneValue = arith::ConstantOp::create(
+ builder, elementType,
builder.getIntegerAttr(elementType,
APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
/*isSigned=*/true)));
@@ -221,82 +221,83 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
// if (p == T(0))
// return T(1);
auto pIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
Block *thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(pIsZero->getBlock());
- builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
// if (p < T(0)) {
builder.setInsertionPointToEnd(fallthroughBlock);
- auto pIsNeg =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
+ auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
+ zeroValue);
// if (b == T(0))
builder.createBlock(funcBody);
auto bIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
// return T(1) / T(0);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(
- builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
+ func::ReturnOp::create(
+ builder,
+ arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(0)).
builder.setInsertionPointToEnd(bIsZero->getBlock());
- builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
// if (b == T(1))
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsOne =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
// return T(1);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(1)).
builder.setInsertionPointToEnd(bIsOne->getBlock());
- builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
// if (b == T(-1)) {
builder.setInsertionPointToEnd(fallthroughBlock);
- auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- bArg, minusOneValue);
+ auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ bArg, minusOneValue);
// if (p & T(1))
builder.createBlock(funcBody);
- auto pIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
- zeroValue);
+ auto pIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
// return T(-1);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(minusOneValue);
+ func::ReturnOp::create(builder, minusOneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(pIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
// return T(1);
// } // b == T(-1)
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(-1)).
builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
- builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
- fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(),
+ fallthroughBlock);
// return T(0);
// } // (p < T(0))
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<func::ReturnOp>(zeroValue);
+ func::ReturnOp::create(builder, zeroValue);
Block *loopHeader = builder.createBlock(
funcBody, funcBody->end(), {elementType, elementType, elementType},
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set up conditional branch for (p < T(0)).
builder.setInsertionPointToEnd(pIsNeg->getBlock());
// Set initial values of 'result', 'b' and 'p' for the loop.
- builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
- ValueRange{oneValue, bArg, pArg});
+ cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader,
+ ValueRange{oneValue, bArg, pArg});
// T result = T(1);
// while (true) {
@@ -313,45 +314,46 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
builder.setInsertionPointToEnd(loopHeader);
// if (p & T(1))
- auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne,
- builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
+ auto powerTmpIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
- Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
+ Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
- resultTmp);
+ cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
+ resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= T(1);
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
+ Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
// if (p == T(0))
- auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- newPowerTmp, zeroValue);
+ auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ newPowerTmp, zeroValue);
// return result;
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(newResultTmp);
+ func::ReturnOp::create(builder, newResultTmp);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
- builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
+ fallthroughBlock);
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
+ Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
- builder.create<cf::BranchOp>(
- ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+ cf::BranchOp::create(
+ builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
return funcOp;
}
@@ -420,7 +422,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << baseType;
nameOS << '_' << powType;
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
@@ -433,46 +435,48 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
- Value oneBValue = builder.create<arith::ConstantOp>(
- baseType, builder.getFloatAttr(baseType, 1.0));
- Value zeroPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, 0));
- Value onePValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, 1));
- Value minPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
- powType.getWidth())));
- Value maxPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
- powType.getWidth())));
+ Value oneBValue = arith::ConstantOp::create(
+ builder, baseType, builder.getFloatAttr(baseType, 1.0));
+ Value zeroPValue = arith::ConstantOp::create(
+ builder, powType, builder.getIntegerAttr(powType, 0));
+ Value onePValue = arith::ConstantOp::create(
+ builder, powType, builder.getIntegerAttr(powType, 1));
+ Value minPValue = arith::ConstantOp::create(
+ builder, powType,
+ builder.getIntegerAttr(
+ powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
+ Value maxPValue = arith::ConstantOp::create(
+ builder, powType,
+ builder.getIntegerAttr(
+ powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
// if (p == Tp{0})
// return Tb{1};
- auto pIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
+ auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
+ zeroPValue);
Block *thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneBValue);
+ func::ReturnOp::create(builder, oneBValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == Tp{0}).
builder.setInsertionPointToEnd(pIsZero->getBlock());
- builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
builder.setInsertionPointToEnd(fallthroughBlock);
// bool isNegativePower{p < Tp{0}}
- auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
- zeroPValue);
+ auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
+ zeroPValue);
// bool isMin{p == std::numeric_limits<Tp>::min()};
auto pIsMin =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
// if (isMin) {
// p = std::numeric_limits<Tp>::max();
// } else if (isNegativePower) {
// p = -p;
// }
- Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
- auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
- pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
+ Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
+ auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
+ pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
// Tb result = Tb{1};
// Tb origBase = Tb{b};
@@ -489,7 +493,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set initial values of 'result', 'b' and 'p' for the loop.
builder.setInsertionPointToEnd(pInit->getBlock());
- builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
+ cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit});
// Create loop body.
Value resultTmp = loopHeader->getArgument(0);
@@ -498,30 +502,30 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
builder.setInsertionPointToEnd(loopHeader);
// if (p & Tp{1})
- auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne,
- builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
+ auto powerTmpIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
- Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
+ Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & Tp{1}).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
- resultTmp);
+ cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
+ resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= Tp{1};
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
+ Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
// if (p == Tp{0})
- auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- newPowerTmp, zeroPValue);
+ auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ newPowerTmp, zeroPValue);
// break;
//
// The conditional branch is finalized below with a jump to
@@ -531,10 +535,10 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
+ Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
- builder.create<cf::BranchOp>(
- ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+ cf::BranchOp::create(
+ builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
// Set up conditional branch for early loop exit:
// if (p == Tp{0})
@@ -542,8 +546,8 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
- builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
- fallthroughBlock, ValueRange{});
+ cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
+ fallthroughBlock, ValueRange{});
// if (isMin) {
// result *= origBase;
@@ -553,11 +557,11 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(loopExit);
- builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
- newResultTmp);
+ cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
+ newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
- newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
/// if (isNegativePower) {
/// result = Tb{1} / result;
@@ -567,15 +571,15 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
- newResultTmp);
+ cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
+ newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
- newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
- builder.create<cf::BranchOp>(newResultTmp, returnBlock);
+ newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
+ cf::BranchOp::create(builder, newResultTmp, returnBlock);
// return result;
builder.setInsertionPointToEnd(returnBlock);
- builder.create<func::ReturnOp>(returnBlock->getArgument(0));
+ func::ReturnOp::create(builder, returnBlock->getArgument(0));
return funcOp;
}
@@ -667,7 +671,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
nameOS << '_' << elementType;
FunctionType funcType =
FunctionType::get(builder.getContext(), {elementType}, elementType);
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
// LinkonceODR ensures that there is only one implementation of this function
// across all math.ctlz functions that are lowered in this way.
@@ -683,33 +687,34 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
Value arg = funcOp.getArgument(0);
Type indexType = builder.getIndexType();
- Value bitWidthValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, bitWidth));
- Value zeroValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 0));
+ Value bitWidthValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, bitWidth));
+ Value zeroValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 0));
Value inputEqZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
// if input == 0, return bit width, else enter loop.
- scf::IfOp ifOp = builder.create<scf::IfOp>(
- elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
+ scf::IfOp ifOp =
+ scf::IfOp::create(builder, elementType, inputEqZero,
+ /*addThenBlock=*/true, /*addElseBlock=*/true);
ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
- Value oneIndex = elseBuilder.create<arith::ConstantOp>(
- indexType, elseBuilder.getIndexAttr(1));
- Value oneValue = elseBuilder.create<arith::ConstantOp>(
- elementType, elseBuilder.getIntegerAttr(elementType, 1));
- Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
- indexType, elseBuilder.getIndexAttr(bitWidth));
- Value nValue = elseBuilder.create<arith::ConstantOp>(
- elementType, elseBuilder.getIntegerAttr(elementType, 0));
-
- auto loop = elseBuilder.create<scf::ForOp>(
- oneIndex, bitWidthIndex, oneIndex,
+ Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
+ elseBuilder.getIndexAttr(1));
+ Value oneValue = arith::ConstantOp::create(
+ elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
+ Value bitWidthIndex = arith::ConstantOp::create(
+ elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
+ Value nValue = arith::ConstantOp::create(
+ elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
+
+ auto loop = scf::ForOp::create(
+ elseBuilder, oneIndex, bitWidthIndex, oneIndex,
// Initial values for two loop induction variables, the arg which is being
// shifted left in each iteration, and the n value which tracks the count
// of leading zeros.
@@ -725,25 +730,25 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
Value argIter = args[0];
Value nIter = args[1];
- Value argIsNonNegative = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, argIter, zeroValue);
- scf::IfOp ifOp = b.create<scf::IfOp>(
- loc, argIsNonNegative,
+ Value argIsNonNegative = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
+ scf::IfOp ifOp = scf::IfOp::create(
+ b, loc, argIsNonNegative,
[&](OpBuilder &b, Location loc) {
// If arg is negative, continue (effectively, break)
- b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
+ scf::YieldOp::create(b, loc, ValueRange{argIter, nIter});
},
[&](OpBuilder &b, Location loc) {
// Otherwise, increment n and shift arg left.
- Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
- Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
- b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
+ Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue);
+ Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue);
+ scf::YieldOp::create(b, loc, ValueRange{argNext, nNext});
});
- b.create<scf::YieldOp>(loc, ifOp.getResults());
+ scf::YieldOp::create(b, loc, ifOp.getResults());
});
- elseBuilder.create<scf::YieldOp>(loop.getResult(1));
+ scf::YieldOp::create(elseBuilder, loop.getResult(1));
- builder.create<func::ReturnOp>(ifOp.getResult(0));
+ func::ReturnOp::create(builder, ifOp.getResult(0));
return funcOp;
}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index f4d69ce8235bb..853f45498ac52 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
- return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
- false);
+ return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
+ false);
},
rewriter);
}
@@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
- expAttrs.getAttrs());
+ auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
+ expAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
@@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto exp = rewriter.create<LLVM::ExpOp>(
- loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
- return rewriter.create<LLVM::FSubOp>(
- loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], expAttrs.getAttrs());
+ return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{exp, one},
+ subAttrs.getAttrs());
},
rewriter);
}
@@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one =
isa<VectorType>(llvmOperandType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne))
- : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
- floatOne);
+ : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
+ floatOne);
- auto add = rewriter.create<LLVM::FAddOp>(
- loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
- addAttrs.getAttrs());
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
+ ValueRange{one, adaptor.getOperand()},
+ addAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
return success();
@@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
- ValueRange{one, operands[0]},
- addAttrs.getAttrs());
- return rewriter.create<LLVM::LogOp>(
- loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, operands[0]},
+ addAttrs.getAttrs());
+ return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{add}, logAttrs.getAttrs());
},
rewriter);
}
@@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (isa<VectorType>(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
- sqrtAttrs.getAttrs());
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
+ sqrtAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
@@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto sqrt = rewriter.create<LLVM::SqrtOp>(
- loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
- return rewriter.create<LLVM::FDivOp>(
- loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], sqrtAttrs.getAttrs());
+ return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, sqrt},
+ divAttrs.getAttrs());
},
rewriter);
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a0ce7d3b75fc2..f7c0d4fe3a799 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -84,20 +84,21 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto shape = vecType.getShape();
int64_t numElements = vecType.getNumElements();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(
- vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(vecType,
+ FloatAttr::get(vecType.getElementType(), 0.0)));
SmallVector<int64_t> strides = computeStrides(shape);
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(
- rewriter.create<vector::ExtractOp>(loc, input, positions));
+ vector::ExtractOp::create(rewriter, loc, input, positions));
Value scalarOp =
- rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ Op::create(rewriter, loc, vecType.getElementType(), operands);
result =
- rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+ vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, {result});
return success();
@@ -114,9 +115,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto f32 = rewriter.getF32Type();
auto extendedOperands = llvm::to_vector(
llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
- return rewriter.create<arith::ExtFOp>(loc, f32, operand);
+ return arith::ExtFOp::create(rewriter, loc, f32, operand);
}));
- auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
+ auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
return success();
}
@@ -139,8 +140,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
- opFunctionTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
+ opFunctionTy);
opFunc.setPrivate();
// By definition Math dialect operations imply LLVM's "readnone"
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 59db14ed816be..a877ad21734a2 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
if (!vectorType.getElementType().isInteger(32))
return nullptr;
SmallVector<int> values(vectorType.getNumElements(), value);
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32VectorAttr(values));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32VectorAttr(values));
}
if (type.isInteger(32))
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32IntegerAttr(value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32IntegerAttr(value));
return nullptr;
}
@@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
Type intType = rewriter.getIntegerType(bitwidth);
uint64_t intValue = uint64_t(1) << (bitwidth - 1);
- Value signMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue));
- Value valueMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
+ Value signMask = spirv::ConstantOp::create(
+ rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
+ Value valueMask = spirv::ConstantOp::create(
+ rewriter, loc, intType,
+ rewriter.getIntegerAttr(intType, intValue - 1u));
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
@@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
intType = VectorType::get(count, intType);
SmallVector<Value> signSplat(count, signMask);
- signMask =
- rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+ signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ signSplat);
SmallVector<Value> valueSplat(count, valueMask);
- valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
- valueSplat);
+ valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ valueSplat);
}
Value lhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
Value rhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
- Value value = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{lhsCast, valueMask});
- Value sign = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{rhsCast, signMask});
+ Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{lhsCast, valueMask});
+ Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{rhsCast, signMask});
- Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
- ValueRange{value, sign});
+ Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
+ ValueRange{value, sign});
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
return success();
}
@@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
- Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
+ Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
// We need to subtract from 31 given that the index returned by GLSL
// FindUMsb is counted from the least significant bit. Theoretically this
// also gives the correct result even if the integer has all zero bits, in
// which case GL FindUMsb would return -1.
- Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+ Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
// However, certain Vulkan implementations have driver bugs for the corner
// case where the input is zero. And.. it can be smart to optimize a select
// only involving the corner case. So separately compute the result when the
// input is either zero or one.
- Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
- Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+ Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
+ Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
subMsb);
return success();
@@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
if (!type)
return failure();
- Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+ Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
return success();
@@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
Value onePlus =
- rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
+ spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
}
@@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
auto getConstantValue = [&](double value) {
if (auto floatType = dyn_cast<FloatType>(type)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type, rewriter.getFloatAttr(floatType, value));
+ return spirv::ConstantOp::create(
+ rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
}
if (auto vectorType = dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (isa<FloatType>(elemType)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ rewriter, loc, type,
DenseFPElementsAttr::get(
vectorType, FloatAttr::get(elemType, value).getValue()));
}
@@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
Value constantValue = getConstantValue(
std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
: log10Reciprocal);
- Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
+ Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
constantValue);
return success();
@@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
Location loc = powfOp.getLoc();
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
- rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
+ spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
// Per C/C++ spec:
// > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
@@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Calculate the reminder from the exponent and check whether it is zero.
Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
Value expRem =
- rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
+ spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
Value expRemNonZero =
- rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
+ spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
Value cmpNegativeWithFractionalExp =
- rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
+ spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
// Create NaN result and replace base value if conditions are met.
const auto &floatSemantics = scalarFloatType.getFloatSemantics();
const auto nan = APFloat::getNaN(floatSemantics);
@@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
nanAttr = DenseElementsAttr::get(vectorType, nan);
Value NanValue =
- rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
- Value lhs = rewriter.create<spirv::SelectOp>(
- loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
- Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
+ spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
+ Value lhs =
+ spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
+ NanValue, adaptor.getLhs());
+ Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
// order to properly propagate the sign, assuming integer y cases. It
@@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =
- rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+ spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
Value bitwiseAndOne =
- rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
- Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+ spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
+ Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
// calculate pow based on abs(lhs)^rhs.
- Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
- Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+ Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
+ Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
// if the exponent is odd and lhs < 0, negate the result.
Value shouldNegate =
- rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+ spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
pow);
return success();
@@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
Value half;
if (VectorType vty = dyn_cast<VectorType>(ty)) {
- half = rewriter.create<spirv::ConstantOp>(
- loc, vty,
+ half = spirv::ConstantOp::create(
+ rewriter, loc, vty,
DenseElementsAttr::get(vty,
rewriter.getFloatAttr(ety, 0.5).getValue()));
} else {
- half = rewriter.create<spirv::ConstantOp>(
- loc, ty, rewriter.getFloatAttr(ety, 0.5));
+ half = spirv::ConstantOp::create(rewriter, loc, ty,
+ rewriter.getFloatAttr(ety, 0.5));
}
- auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
- auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
- auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
+ auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
+ auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
+ auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
auto greater =
- rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
- auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
- auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
+ spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
+ auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
+ auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
return success();
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0b7ffa40ec09d..e882845d9d99a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -160,8 +160,8 @@ struct ConvertGetGlobal final
if (opTy.getRank() == 0) {
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
- emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>(
- op.getLoc(), lvalueType, operands.getNameAttr());
+ emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
+ rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
@@ -191,8 +191,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
+ auto subscript = emitc::SubscriptOp::create(
+ rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
return success();
@@ -211,8 +211,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
+ auto subscript = emitc::SubscriptOp::create(
+ rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
@@ -242,7 +242,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
if (inputs.size() != 1)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
};
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 83681b2d5fd87..53a19129103a3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -87,12 +87,12 @@ getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
/// aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment) {
- Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
- rewriter.getIndexAttr(1));
- Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
- Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
- Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
- return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
+ rewriter.getIndexAttr(1));
+ Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
+ Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
+ Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
+ return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}
/// Computes the byte size for the MemRef element type.
@@ -123,8 +123,9 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
+ allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc,
+ LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
allocatedPtr);
return allocatedPtr;
}
@@ -168,14 +169,14 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
Value alignment = getAlignment(rewriter, loc, op);
if (alignment) {
// Adjust the allocation size to consider alignment.
- sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
+ sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
}
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
assert(elementPtrType && "could not compute element ptr type");
auto results =
- rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+ LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -184,11 +185,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
if (alignment) {
// Compute the aligned pointer.
Value allocatedInt =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
+ LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
Value alignmentInt =
createAligned(rewriter, loc, allocatedInt, alignment);
alignedPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
+ LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
}
// Create the MemRef descriptor.
@@ -268,8 +269,9 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
- auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
+ auto results =
+ LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
+ ValueRange({allocAlignment, sizeBytes}));
Value ptr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -360,8 +362,9 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
auto elementPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
- auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
- loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
+ auto allocatedElementPtr =
+ LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
+ op.getAlignment().value_or(0));
// Create the MemRef descriptor.
auto memRefDescriptor = this->createMemRefDescriptor(
@@ -397,7 +400,7 @@ struct AllocaScopeOpLowering
remainingOpsBlock, allocaScopeOp.getResultTypes(),
SmallVector<Location>(allocaScopeOp->getNumResults(),
allocaScopeOp.getLoc()));
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
}
// Inline body region.
@@ -407,8 +410,8 @@ struct AllocaScopeOpLowering
// Save stack and then branch into the body of the region.
rewriter.setInsertionPointToEnd(currentBlock);
- auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType());
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
+ auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);
// Replace the alloca_scope return with a branch that jumps out of the body.
// Stack restore before leaving the body region.
@@ -420,7 +423,7 @@ struct AllocaScopeOpLowering
// Insert stack restore before jumping out the body of the region.
rewriter.setInsertionPoint(branchOp);
- rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+ LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
// Replace the op with values return from the body region.
rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
@@ -451,11 +454,11 @@ struct AssumeAlignmentOpLowering
// This is more direct than ptrtoint-based checks, is explicitly supported,
// and works with non-integral address spaces.
Value trueCond =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
Value alignmentConst =
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
- rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
- alignmentConst);
+ LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+ alignmentConst);
rewriter.replaceOp(op, memref);
return success();
}
@@ -559,18 +562,19 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// Get pointer to offset field of memref<element_type> descriptor.
auto indexPtrTy =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
- Value offsetPtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, elementType, underlyingRankedDesc,
- ArrayRef<LLVM::GEPArg>{0, 2});
+ Value offsetPtr =
+ LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
+ underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
- Value idxPlusOne = rewriter.create<LLVM::AddOp>(
- loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
+ Value idxPlusOne = LLVM::AddOp::create(
+ rewriter, loc,
+ createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
adaptor.getIndex());
- Value sizePtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
- idxPlusOne);
+ Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
+ getTypeConverter()->getIndexType(),
+ offsetPtr, idxPlusOne);
return rewriter
.create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
.getResult();
@@ -674,9 +678,10 @@ struct GenericAtomicRMWOpLowering
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
auto dataPtr = getStridedElementPtr(
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
- Value init = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
- rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+ Value init = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
+ dataPtr);
+ LLVM::BrOp::create(rewriter, loc, init, loopBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
@@ -696,15 +701,16 @@ struct GenericAtomicRMWOpLowering
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
- loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
+ auto cmpxchg =
+ LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
+ result, successOrdering, failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
- Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
- Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
+ Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
+ Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
// Conditionally branch to the end or back to the loop depending on %ok.
- rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
- loopBlock, newLoaded);
+ LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
+ loopBlock, newLoaded);
rewriter.setInsertionPointToEnd(endBlock);
@@ -796,8 +802,8 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
if (!isExternal && isUninitialized) {
rewriter.createBlock(&newGlobal.getInitializerRegion());
Value undef[] = {
- rewriter.create<LLVM::UndefOp>(newGlobal.getLoc(), arrayTy)};
- rewriter.create<LLVM::ReturnOp>(newGlobal.getLoc(), undef);
+ LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
+ LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
}
return success();
}
@@ -842,13 +848,13 @@ struct GetGlobalMemrefOpLowering
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
auto addressOf =
- rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
+ LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
- auto gep = rewriter.create<LLVM::GEPOp>(
- loc, ptrTy, arrayTy, addressOf,
- SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
+ auto gep =
+ LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
+ SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
// We do not expect the memref obtained using `memref.get_global` to be
// ever deallocated. Set the allocated pointer to be known bad value to
@@ -857,7 +863,7 @@ struct GetGlobalMemrefOpLowering
Value deadBeefConst =
createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
auto deadBeefPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
// Both allocated and aligned pointers are same. We could potentially stash
// a nullptr for the allocated pointer since we do not expect any dealloc.
@@ -1009,8 +1015,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
loc, adaptor.getSource(), rewriter);
// rank = ConstantOp srcRank
- auto rankVal = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(rank));
+ auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(rank));
// poison = PoisonOp
UnrankedMemRefDescriptor memRefDesc =
UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
@@ -1029,7 +1035,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// struct = LoadOp ptr
- auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
+ auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
} else {
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
@@ -1063,32 +1069,33 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
MemRefDescriptor srcDesc(adaptor.getSource());
// Compute number of elements.
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(1));
+ Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1));
for (int pos = 0; pos < srcType.getRank(); ++pos) {
auto size = srcDesc.size(rewriter, loc, pos);
- numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
}
// Get element size.
auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
// Compute total.
Value totalSize =
- rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+ LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
Type elementType = typeConverter->convertType(srcType.getElementType());
Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
Value srcOffset = srcDesc.offset(rewriter, loc);
- Value srcPtr = rewriter.create<LLVM::GEPOp>(
- loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
+ Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
+ elementType, srcBasePtr, srcOffset);
MemRefDescriptor targetDesc(adaptor.getTarget());
Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
Value targetOffset = targetDesc.offset(rewriter, loc);
- Value targetPtr = rewriter.create<LLVM::GEPOp>(
- loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
- rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
- /*isVolatile=*/false);
+ Value targetPtr =
+ LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
+ targetBasePtr, targetOffset);
+ LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
+ /*isVolatile=*/false);
rewriter.eraseOp(op);
return success();
@@ -1103,8 +1110,8 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
// First make sure we have an unranked memref descriptor representation.
auto makeUnranked = [&, this](Value ranked, MemRefType type) {
- auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- type.getRank());
+ auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ type.getRank());
auto *typeConverter = getTypeConverter();
auto ptr =
typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
@@ -1116,7 +1123,7 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
};
// Save stack position before promoting descriptors
- auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType());
+ auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
auto srcMemRefType = dyn_cast<MemRefType>(srcType);
Value unrankedSource =
@@ -1128,13 +1135,13 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
: adaptor.getTarget();
// Now promote the unranked descriptors to the stack.
- auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(1));
+ auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1));
auto promote = [&](Value desc) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
auto allocated =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
- rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
+ LLVM::StoreOp::create(rewriter, loc, desc, allocated);
return allocated;
};
@@ -1149,11 +1156,11 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
sourcePtr.getType(), symbolTables);
if (failed(copyFn))
return failure();
- rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
- ValueRange{elemSize, sourcePtr, targetPtr});
+ LLVM::CallOp::create(rewriter, loc, copyFn.value(),
+ ValueRange{elemSize, sourcePtr, targetPtr});
// Restore stack used for descriptors
- rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+ LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
rewriter.eraseOp(op);
@@ -1204,9 +1211,9 @@ struct MemorySpaceCastOpLowering
MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
descVals);
descVals[0] =
- rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
+ LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
descVals[1] =
- rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
+ LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
resultTypeR, descVals);
rewriter.replaceOp(op, result);
@@ -1241,8 +1248,9 @@ struct MemorySpaceCastOpLowering
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
result, resultAddrSpace, sizes);
Value resultUnderlyingSize = sizes.front();
- Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
- loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
+ Value resultUnderlyingDesc =
+ LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
+ rewriter.getI8Type(), resultUnderlyingSize);
result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
// Copy pointers, performing address space casts.
@@ -1256,10 +1264,10 @@ struct MemorySpaceCastOpLowering
Value alignedPtr =
sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
sourceUnderlyingDesc, sourceElemPtrType);
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, resultElemPtrType, allocatedPtr);
- alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, resultElemPtrType, alignedPtr);
+ allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc, resultElemPtrType, allocatedPtr);
+ alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
+ resultElemPtrType, alignedPtr);
result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
resultElemPtrType, allocatedPtr);
@@ -1277,12 +1285,13 @@ struct MemorySpaceCastOpLowering
int64_t bytesToSkip =
2 * llvm::divideCeil(
getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
- Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
- Value copySize = rewriter.create<LLVM::SubOp>(
- loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
- rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
- copySize, /*isVolatile=*/false);
+ Value bytesToSkipConst = LLVM::ConstantOp::create(
+ rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
+ Value copySize =
+ LLVM::SubOp::create(rewriter, loc, getIndexType(),
+ resultUnderlyingSize, bytesToSkipConst);
+ LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
+ copySize, /*isVolatile=*/false);
rewriter.replaceOp(op, ValueRange{result});
return success();
@@ -1485,7 +1494,7 @@ struct MemRefReshapeOpLowering
} else {
Value shapeOp = reshapeOp.getShape();
Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
- dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+ dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
@@ -1497,7 +1506,7 @@ struct MemRefReshapeOpLowering
desc.setStride(rewriter, loc, i, stride);
// Prepare the stride value for the next dimension.
- stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
+ stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
}
*descriptor = desc;
@@ -1522,8 +1531,9 @@ struct MemRefReshapeOpLowering
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
targetDesc, addressSpace, sizes);
- Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
- loc, getPtrType(), IntegerType::get(getContext(), 8), sizes.front());
+ Value underlyingDescPtr = LLVM::AllocaOp::create(
+ rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
+ sizes.front());
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
@@ -1554,7 +1564,7 @@ struct MemRefReshapeOpLowering
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
Value resultRankMinusOne =
- rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
+ LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
Block *initBlock = rewriter.getInsertionBlock();
Type indexType = getTypeConverter()->getIndexType();
@@ -1568,15 +1578,15 @@ struct MemRefReshapeOpLowering
rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
rewriter.setInsertionPointToEnd(initBlock);
- rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
- condBlock);
+ LLVM::BrOp::create(rewriter, loc,
+ ValueRange({resultRankMinusOne, oneIndex}), condBlock);
rewriter.setInsertionPointToStart(condBlock);
Value indexArg = condBlock->getArgument(0);
Value strideArg = condBlock->getArgument(1);
Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
- Value pred = rewriter.create<LLVM::ICmpOp>(
- loc, IntegerType::get(rewriter.getContext(), 1),
+ Value pred = LLVM::ICmpOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
Block *bodyBlock =
@@ -1585,31 +1595,31 @@ struct MemRefReshapeOpLowering
// Copy size from shape to descriptor.
auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
- loc, llvmIndexPtrType,
+ Value sizeLoadGep = LLVM::GEPOp::create(
+ rewriter, loc, llvmIndexPtrType,
typeConverter->convertType(shapeMemRefType.getElementType()),
shapeOperandPtr, indexArg);
- Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
+ Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
targetSizesBase, indexArg, size);
// Write stride value and compute next one.
UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
targetStridesBase, indexArg, strideArg);
- Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
+ Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
// Decrement loop counter and branch back.
- Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
- rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
- condBlock);
+ Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
+ LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
+ condBlock);
Block *remainder =
rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, ValueRange(),
- remainder, ValueRange());
+ LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
+ remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
@@ -1738,7 +1748,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
if (nextSize)
return runningStride
- ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+ ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
: nextSize;
assert(!runningStride);
return createIndexAttrConstant(rewriter, loc, indexType, 1);
@@ -1783,8 +1793,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Field 2: Copy the actual aligned pointer to payload.
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
- alignedPtr = rewriter.create<LLVM::GEPOp>(
- loc, alignedPtr.getType(),
+ alignedPtr = LLVM::GEPOp::create(
+ rewriter, loc, alignedPtr.getType(),
typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
adaptor.getByteShift());
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index b866afbce98b0..7a705336bf11c 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
assert(indices.size() == 2);
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
- return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
+ return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
+ indices);
}
/// Casts the given `srcBool` into an integer of `dstType`.
@@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
value = castBoolToIntN(loc, value, dstType, builder);
} else {
if (valueBits < targetBits) {
- value = builder.create<spirv::UConvertOp>(
- loc, builder.getIntegerType(targetBits), value);
+ value = spirv::UConvertOp::create(
+ builder, loc, builder.getIntegerType(targetBits), value);
}
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
@@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
- varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
- /*initializer=*/nullptr);
+ varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
+ /*initializer=*/nullptr);
}
// Get pointer to global variable at the current scope.
@@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
- memoryAccess, alignment);
+ Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
- memoryAccess, alignment);
+ Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!scope)
return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
- Value result = rewriter.create<spirv::AtomicAndOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- clearBitsMask);
- result = rewriter.create<spirv::AtomicOrOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- storeVal);
+ Value result = spirv::AtomicAndOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+ result = spirv::AtomicOrOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
// The AtomicOrOp has no side effect. Since it is already inserted, we can
// just remove the original StoreOp. Note that rewriter.replaceOp()
@@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
genericPtrType = typeConverter.convertType(intermediateType);
}
if (sourceSc != spirv::StorageClass::Generic) {
- result =
- rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
+ result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
+ result);
}
if (resultSc != spirv::StorageClass::Generic) {
result =
- rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
+ spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
}
rewriter.replaceOp(addrCastOp, result);
return success();
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b93128441f2b5..63b1fdabaf407 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -65,7 +65,7 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
values.emplace_back(*(dyn++));
} else {
TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
- values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+ values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
}
}
return values;
@@ -79,9 +79,9 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
SmallVector<Value> multiIndex(n);
for (int i = n - 1; i >= 0; --i) {
- multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
+ multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
if (i > 0)
- linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
+ linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
}
return multiIndex;
@@ -91,13 +91,13 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
- Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
+ Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
+ Value stride = arith::ConstantIndexOp::create(b, loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
- linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
- stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+ Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
+ linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
+ stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
}
return linearIndex;
@@ -144,11 +144,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto i64 = rewriter.getI64Type();
std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
maxNAxes};
- Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
auto attr = IntegerAttr::get(i16, -1);
- Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
- resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
- .getResult(0);
+ Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
+ resSplitAxes =
+ linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
+ .getResult(0);
// explicitly write values into tensor row by row
std::array<int64_t, 2> strides = {1, 1};
@@ -162,9 +163,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
std::array<int64_t, 2> sizes = {1, size};
auto tensorType = RankedTensorType::get({size}, i16);
auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
- auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
- resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
+ resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resSplitAxes, empty, empty,
+ empty, offs, sizes, strides);
}
// To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
@@ -179,7 +181,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
.create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
i64)
.getResult()
- : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
// To hold sharded dims offsets, create Tensor with shape {nSplits,
@@ -189,8 +191,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// MeshOp)
Value resOffsets;
if (adaptor.getStaticShardedDimsOffsets().empty()) {
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{0, 0}, i64);
+ resOffsets = tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64);
} else {
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
@@ -204,12 +206,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
assert(maxSplitSize);
++maxSplitSize; // add one for the total size
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+ resOffsets = tensor::EmptyOp::create(
+ rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
resOffsets =
- rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+ linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
SmallVector<Value> offsets =
getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
adaptor.getDynamicShardedDimsOffsets());
@@ -220,11 +222,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
- Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
std::array<int64_t, 2> sizes = {1, splitSize};
- resOffsets = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+ resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resOffsets, empty, empty,
+ empty, offs, sizes, strides);
curr += splitSize;
}
}
@@ -236,10 +239,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
return failure();
resSplitAxes =
- rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+ tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
resHaloSizes =
- rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
- resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+ tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
+ resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TupleType::get(op.getContext(), resTypes),
@@ -269,9 +272,9 @@ struct ConvertProcessMultiIndexOp
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
- Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
@@ -302,7 +305,7 @@ class ConvertProcessLinearIndexOp
Location loc = op.getLoc();
auto ctx = op.getContext();
Value commWorld =
- rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
+ mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
auto rank =
rewriter
.create<mpi::CommRankOp>(
@@ -341,41 +344,41 @@ struct ConvertNeighborsLinearIndicesOp
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
Value dimSz = dims[axes[0]];
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
- Value atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
- auto down = rewriter.create<scf::IfOp>(
- loc, atBorder,
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
+ Value atBorder =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+ auto down = scf::IfOp::create(
+ rewriter, loc, atBorder,
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
+ scf::YieldOp::create(builder, loc, minus1);
},
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
+ arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
.getResult();
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
});
- atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, orgIdx,
- rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
- auto up = rewriter.create<scf::IfOp>(
- loc, atBorder,
+ atBorder = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
+ arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
+ auto up = scf::IfOp::create(
+ rewriter, loc, atBorder,
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
+ scf::YieldOp::create(builder, loc, minus1);
},
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
+ arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
return success();
@@ -447,8 +450,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
rewriter, loc, sharding.getStaticShardedDimsOffsets(),
sharding.getDynamicShardedDimsOffsets(), index);
if (!tmp.empty())
- shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+ shardedDimsOffs = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
+ tmp);
}
// With static mesh shape the sizes of the split axes are known.
@@ -457,9 +461,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
int64_t pos = 0;
SmallVector<Value> shardShape;
Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
// Iterate over the dimensions of the tensor shape, get their split Axes,
// and compute the sharded shape.
@@ -469,8 +473,8 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
auto axes = splitAxes[i];
// The current dimension might not be sharded.
// Create a value from the static position in shardDimsOffsets.
- Value posVal =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+ Value posVal = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIndexAttr(pos));
// Get the index of the local shard in the mesh axis.
Value idx = multiIdx[axes[0]];
auto numShards =
@@ -482,29 +486,29 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
return op->emitError() << "Only single axis sharding is "
<< "supported for each dimension.";
}
- idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
// Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
Value off =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ idx = arith::AddIOp::create(rewriter, loc, idx, one);
Value nextOff =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
shardShape.emplace_back(sz);
} else {
- Value numShardsVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(numShards));
+ Value numShardsVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(numShards));
// Compute shard dim size by distributing odd elements to trailing
// shards:
// sz = dim / numShards
// + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
- Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
- Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
- sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
- auto cond = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, idx, sz1);
- Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
- sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+ Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
+ Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
+ sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
+ auto cond = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
+ Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
+ sz = arith::AddIOp::create(rewriter, loc, sz, odd);
shardShape.emplace_back(sz);
}
pos += numShards + 1; // add one for the total size.
@@ -568,7 +572,7 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
if (isa<RankedTensorType>(input.getType())) {
auto memrefType = MemRefType::get(
inputShape, cast<ShapedType>(input.getType()).getElementType());
- input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
+ input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
}
MemRefType inType = cast<MemRefType>(input.getType());
@@ -577,15 +581,15 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
for (auto i = 0; i < inType.getRank(); ++i) {
auto s = inputShape[i];
if (ShapedType::isDynamic(s))
- shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult();
+ shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
else
shape[i] = iBuilder.getIndexAttr(s);
}
// Allocate buffer and copy input to buffer.
- Value buffer = iBuilder.create<memref::AllocOp>(
- shape, cast<ShapedType>(op.getType()).getElementType());
- iBuilder.create<linalg::CopyOp>(input, buffer);
+ Value buffer = memref::AllocOp::create(
+ iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
+ linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for the AllReduce operation.
// The color is the linear index of the process in the mesh along the
@@ -594,9 +598,9 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
- iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
.getResult();
- Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+ Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
auto redAxes = adaptor.getMeshAxes();
@@ -607,15 +611,15 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
Value color =
createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
- color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+ color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
- key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+ key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
auto commType = mpi::CommType::get(op->getContext());
- Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+ Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
auto comm =
- iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+ mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
.getNewcomm();
Value buffer1d = buffer;
@@ -623,19 +627,19 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
if (inType.getRank() > 1) {
ReassociationIndices reassociation(inType.getRank());
std::iota(reassociation.begin(), reassociation.end(), 0);
- buffer1d = iBuilder.create<memref::CollapseShapeOp>(
- buffer, ArrayRef<ReassociationIndices>(reassociation));
+ buffer1d = memref::CollapseShapeOp::create(
+ iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
}
// Create the MPI AllReduce operation.
- iBuilder.create<mpi::AllReduceOp>(
- TypeRange(), buffer1d, buffer1d,
- getMPIReductionOp(adaptor.getReductionAttr()), comm);
+ mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+ getMPIReductionOp(adaptor.getReductionAttr()),
+ comm);
// If the destination is a memref, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
- buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
- true);
+ buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
+ true);
rewriter.replaceOp(op, buffer);
return success();
@@ -676,9 +680,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
if (auto value = dyn_cast<Value>(v))
return value;
- return rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ return arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
auto dest = adaptor.getDestination();
@@ -689,7 +694,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto mmemrefType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
array =
- rewriter.create<bufferization::ToBufferOp>(loc, mmemrefType, array);
+ bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -713,7 +718,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
for (auto i = 0; i < rank; ++i) {
auto s = dstShape[i];
if (ShapedType::isDynamic(s))
- shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+ shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
else
shape[i] = rewriter.getIndexAttr(s);
@@ -723,12 +728,12 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
offsets[i] = haloSizes[currHaloDim * 2];
// prepare shape and offsets of highest dim's halo exchange
- Value _haloSz = rewriter.create<arith::AddIOp>(
- loc, toValue(haloSizes[currHaloDim * 2]),
+ Value _haloSz = arith::AddIOp::create(
+ rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
toValue(haloSizes[currHaloDim * 2 + 1]));
// the halo shape of lower dims exlude the halos
dimSizes[i] =
- rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+ arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
.getResult();
} else {
dimSizes[i] = shape[i];
@@ -736,14 +741,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
- auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
+ auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
- auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
- rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
.getResult();
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -758,20 +763,22 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
splitAxes)
.getResults();
// MPI operates on i32...
- Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[0]),
- rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[1])};
+ Value neighbourIDs[2] = {
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[0]),
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[1])};
auto lowerRecvOffset = rewriter.getIndexAttr(0);
auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
- auto upperRecvOffset = rewriter.create<arith::SubIOp>(
- loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
- auto upperSendOffset = rewriter.create<arith::SubIOp>(
- loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
+ auto upperRecvOffset =
+ arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
+ auto upperSendOffset = arith::SubIOp::create(
+ rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
- Value commWorld = rewriter.create<mpi::CommWorldOp>(
- loc, mpi::CommType::get(op->getContext()));
+ Value commWorld = mpi::CommWorldOp::create(
+ rewriter, loc, mpi::CommType::get(op->getContext()));
// Make sure we send/recv in a way that does not lead to a dead-lock.
// The current approach is by far not optimal, this should be at least
@@ -787,37 +794,38 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// Processes on the mesh borders have only one neighbor
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto hasFrom = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, from, zero);
- auto hasTo = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, to, zero);
- auto buffer = rewriter.create<memref::AllocOp>(
- loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
+ auto hasFrom = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::sge, to, zero);
+ auto buffer = memref::AllocOp::create(
+ rewriter, loc, dimSizes,
+ cast<ShapedType>(array.getType()).getElementType());
// if has neighbor: copy halo data from array to buffer and send
- rewriter.create<scf::IfOp>(
- loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ scf::IfOp::create(
+ rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
: OpFoldResult(upperSendOffset);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, subview, buffer);
- builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
- commWorld);
- builder.create<scf::YieldOp>(loc);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, subview, buffer);
+ mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
+ commWorld);
+ scf::YieldOp::create(builder, loc);
});
// if has neighbor: receive halo data into buffer and copy to array
- rewriter.create<scf::IfOp>(
- loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ scf::IfOp::create(
+ rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
: OpFoldResult(lowerRecvOffset);
- builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
- commWorld);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, buffer, subview);
- builder.create<scf::YieldOp>(loc);
+ mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
+ commWorld);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, buffer, subview);
+ scf::YieldOp::create(builder, loc);
});
- rewriter.create<memref::DeallocOp>(loc, buffer);
+ memref::DeallocOp::create(rewriter, loc, buffer);
offsets[dim] = orgOffset;
};
@@ -825,16 +833,17 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
Value haloSz = dyn_cast<Value>(v);
if (!haloSz)
- haloSz = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
- auto hasSize = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, haloSz, zero);
- rewriter.create<scf::IfOp>(loc, hasSize,
- [&](OpBuilder &builder, Location loc) {
- genSendRecv(upOrDown > 0);
- builder.create<scf::YieldOp>(loc);
- });
+ haloSz = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ auto hasSize = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ scf::IfOp::create(rewriter, loc, hasSize,
+ [&](OpBuilder &builder, Location loc) {
+ genSendRecv(upOrDown > 0);
+ scf::YieldOp::create(builder, loc);
+ });
};
doSendRecv(0);
@@ -852,8 +861,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
rewriter.replaceOp(op, array);
} else {
assert(isa<RankedTensorType>(op.getResult().getType()));
- rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
- loc, op.getResult().getType(), array,
+ rewriter.replaceOp(op, bufferization::ToTensorOp::create(
+ rewriter, loc, op.getResult().getType(), array,
/*restrict=*/true, /*writable=*/true));
}
return success();
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..905287e107b0b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -53,7 +53,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
- return b.create<LLVM::TruncOp>(b.getI32Type(), value);
+ return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@@ -113,8 +113,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
- return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
- rewriter.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
+ rewriter.getI32IntegerAttr(index));
};
if (arrayType) {
@@ -126,7 +126,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
arrayType.getElementType() == f32x1Ty) {
for (unsigned i = 0; i < structType.getBody().size(); i++) {
Value el =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
el = rewriter.createOrFold<LLVM::BitcastOp>(
loc, arrayType.getElementType(), el);
elements.push_back(el);
@@ -143,24 +143,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
Value vec =
- rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
+ LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
Value x1 =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
- Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
- i * 2 + 1);
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x1, makeConst(0));
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x2, makeConst(1));
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
+ Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
+ i * 2 + 1);
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x1, makeConst(0));
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x2, makeConst(1));
elements.push_back(vec);
}
}
// Create the final vectorized result.
- Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
for (const auto &el : llvm::enumerate(elements)) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
- el.index());
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
+ el.index());
}
return result;
}
@@ -187,7 +187,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
+ Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@@ -195,7 +195,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
- result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
+ result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
continue;
}
@@ -208,9 +208,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
- result.push_back(b.create<LLVM::ExtractElementOp>(
- toUse,
- b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
+ result.push_back(LLVM::ExtractElementOp::create(
+ b, toUse,
+ LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
- Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
- ldMatrixResultType, srcPtr,
+ Value ldMatrixResult = NVVM::LdMatrixOp::create(
+ b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@@ -296,13 +296,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = b.create<LLVM::PoisonOp>(finalResultType);
+ Value result = LLVM::PoisonOp::create(b, finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
- num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+ num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
: ldMatrixResult;
- Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
- result = b.create<LLVM::InsertValueOp>(result, casted, i);
+ Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
+ result = LLVM::InsertValueOp::create(b, result, casted, i);
}
rewriter.replaceOp(op, result);
@@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = b.create<NVVM::MmaOp>(
- intrinsicResTy, matA, matB, matC,
- /*shape=*/gemmShape,
- /*b1Op=*/std::nullopt,
- /*intOverflow=*/overflow,
- /*multiplicandPtxTypes=*/
- std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
- /*multiplicandLayouts=*/
- std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
- NVVM::MMALayout::col});
+ Value intrinsicResult =
+ NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
+ /*shape=*/gemmShape,
+ /*b1Op=*/std::nullopt,
+ /*intOverflow=*/overflow,
+ /*multiplicandPtxTypes=*/
+ std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
+ /*multiplicandLayouts=*/
+ std::array<NVVM::MMALayout, 2>{
+ NVVM::MMALayout::row, NVVM::MMALayout::col});
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
desiredRetTy, intrinsicResult,
rewriter));
@@ -565,15 +565,16 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- return b.create<LLVM::InlineAsmOp>(
- /*resultTypes=*/intrinsicResultType,
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/constraintStr,
- /*has_side_effects=*/true,
- /*is_align_stack=*/false, LLVM::TailCallKind::None,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
+ return LLVM::InlineAsmOp::create(b,
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ LLVM::TailCallKind::None,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -631,7 +632,7 @@ struct NVGPUMmaSparseSyncLowering
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
- b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
+ LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
@@ -682,7 +683,7 @@ struct NVGPUAsyncCopyLowering
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
+ scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -697,13 +698,13 @@ struct NVGPUAsyncCopyLowering
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
Value c3I32 =
- b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
- Value bitwidth = b.create<LLVM::ConstantOp>(
- b.getI32Type(),
+ LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
+ Value bitwidth = LLVM::ConstantOp::create(
+ b, b.getI32Type(),
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
- Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
- srcBytes = b.create<LLVM::LShrOp>(
- b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
+ Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
+ srcBytes = LLVM::LShrOp::create(
+ b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
- b.create<NVVM::CpAsyncOp>(
- dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ NVVM::CpAsyncOp::create(
+ b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
- Value zero = b.create<LLVM::ConstantOp>(
- IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
+ Value zero =
+ LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -733,11 +735,11 @@ struct NVGPUAsyncCreateGroupLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
+ NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
// Drop the result token.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), IntegerType::get(op.getContext(), 32),
- rewriter.getI32IntegerAttr(0));
+ Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -753,7 +755,7 @@ struct NVGPUAsyncWaitLowering
ConversionPatternRewriter &rewriter) const override {
// If numGroup is not present pick 0 as a conservative correct value.
int32_t numGroups = adaptor.getNumGroups().value_or(0);
- rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
+ NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
rewriter.eraseOp(op);
return success();
}
@@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering
SymbolTable symbolTable(moduleOp);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp.front());
- auto global = rewriter.create<memref::GlobalOp>(
- funcOp->getLoc(), "__mbarrier",
+ auto global = memref::GlobalOp::create(
+ rewriter, funcOp->getLoc(), "__mbarrier",
/*sym_visibility=*/rewriter.getStringAttr("private"),
/*type=*/barrierType,
/*initial_value=*/ElementsAttr(),
@@ -974,7 +976,7 @@ struct NVGPUMBarrierTryWaitParityLowering
adaptor.getMbarId(), rewriter);
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
- b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+ LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -1063,16 +1065,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
auto ti64 = b.getIntegerType(64);
auto makeConst = [&](uint64_t index) -> Value {
- return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
};
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
+ return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
};
auto shiftRight = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
+ return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
};
auto insertBit = [&](Value desc, Value val, int startBit) {
- return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
+ return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
};
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1086,7 +1088,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
Value baseAddr = getStridedElementPtr(
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {});
- Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
+ Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
@@ -1118,8 +1120,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
};
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
- return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
- b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, b.getIntegerType(64),
+ b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
@@ -1182,12 +1184,12 @@ struct NVGPUTmaCreateDescriptorOpLowering
auto promotedOperands = getTypeConverter()->promoteOperands(
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
- Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
- makeI64Const(b, 5));
+ Value boxArrayPtr = LLVM::AllocaOp::create(
+ b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
- Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
- boxArrayPtr, makeI64Const(b, index));
- b.create<LLVM::StoreOp>(value, gep);
+ Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
+ boxArrayPtr, makeI64Const(b, index));
+ LLVM::StoreOp::create(b, value, gep);
}
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1337,7 +1339,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1430,29 +1432,30 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- return b.create<NVVM::WgmmaMmaAsyncOp>(
- matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
- itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+ return NVVM::WgmmaMmaAsyncOp::create(
+ b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
+ itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
Value generateWgmmaGroup() {
Value wgmmaResult =
- b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
+ LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
// Perform GEMM
SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
+ Value matrixC =
+ LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
- wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
- wgmmaResult, matrix, idx);
+ wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
}
return wgmmaResult;
}
@@ -1486,10 +1489,10 @@ struct NVGPUWarpgroupMmaOpLowering
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
Value generateWarpgroupMma() {
- b.create<NVVM::WgmmaFenceAlignedOp>();
+ NVVM::WgmmaFenceAlignedOp::create(b);
Value wgmmaResult = generateWgmmaGroup();
- b.create<NVVM::WgmmaGroupSyncAlignedOp>();
- b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+ NVVM::WgmmaGroupSyncAlignedOp::create(b);
+ NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
return wgmmaResult;
}
};
@@ -1557,7 +1560,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Type i32 = b.getI32Type();
auto makeConst = [&](int32_t index) -> Value {
- return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
};
Value c1 = makeConst(1);
Value c2 = makeConst(2);
@@ -1567,29 +1570,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value warpSize = makeConst(kWarpSize);
auto makeMul = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
+ return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
- Value idx = b.create<arith::IndexCastOp>(it, x);
- Value idy0 = b.create<arith::IndexCastOp>(it, y);
- Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
- Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
- Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
- b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
- b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
+ Value idx = arith::IndexCastOp::create(b, it, x);
+ Value idy0 = arith::IndexCastOp::create(b, it, y);
+ Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
+ Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
+ Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
+ memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
+ memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
};
- Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
- Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
- Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
- Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
- Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ Value tidx = NVVM::ThreadIdXOp::create(b, i32);
+ Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
+ Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
+ Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
+ Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
@@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(matrixD);
- Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ Value innerStructValue =
+ LLVM::ExtractValueOp::create(b, matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
@@ -1648,23 +1652,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
.getBody()
.front();
- Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
- Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
+ Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
+ Value packStruct = LLVM::PoisonOp::create(b, packStructType);
SmallVector<Value> innerStructs;
// Unpack the structs and set all values to zero
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(s);
- Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+ Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
- structValue = b.create<LLVM::InsertValueOp>(
- structType, structValue, zero, ArrayRef<int64_t>({i}));
+ structValue = LLVM::InsertValueOp::create(b, structType, structValue,
+ zero, ArrayRef<int64_t>({i}));
}
innerStructs.push_back(structValue);
}
// Pack the inner structs into a single struct
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
- packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
- packStruct, matrix, idx);
+ packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
+ packStruct, matrix, idx);
}
rewriter.replaceOp(op, packStruct);
return success();
@@ -1681,7 +1685,7 @@ struct NVGPUTmaFenceOpLowering
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto i32Ty = b.getI32Type();
Value tensormapSize =
- b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
+ LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
auto memscope =
NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
@@ -1716,13 +1720,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
VectorType inTy = op.getIn().getType();
// apply rcp.approx.ftz.f on each element in vector.
auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
- Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
+ Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
for (int i = 0; i < numElems; i++) {
- Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
- Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
- Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
- ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
+ Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
+ Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
+ Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
+ ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
}
return ret1DVec;
};
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
index 479725aae8afd..f5b3689c88d26 100644
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -39,8 +39,8 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
IntegerAttr constAttr;
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
- auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
- op.getIfCond(), false);
+ auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), TypeRange(),
+ op.getIfCond(), false);
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7ac9687c4eeda..021e31a8ecd97 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -95,8 +95,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
}
// Create new operation.
- auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
- convertedAttrs);
+ auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
+ convertedAttrs);
// Translate regions.
for (auto [originalRegion, convertedRegion] :
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 7d20109b3db59..b711e33cfc0d6 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -196,7 +196,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion,
// finalize.
if (isa<ExitNode>(node)) {
builder.setInsertionPointToEnd(block);
- builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
+ pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
return block;
}
@@ -272,8 +272,8 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
auto *operationPos = cast<OperationPosition>(pos);
if (operationPos->isOperandDefiningOp())
// Standard (downward) traversal which directly follows the defining op.
- value = builder.create<pdl_interp::GetDefiningOpOp>(
- loc, builder.getType<pdl::OperationType>(), parentVal);
+ value = pdl_interp::GetDefiningOpOp::create(
+ builder, loc, builder.getType<pdl::OperationType>(), parentVal);
else
// A passthrough operation position.
value = parentVal;
@@ -287,23 +287,23 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
// requested to use a representative value (e.g., upward traversal).
if (isa<pdl::RangeType>(parentVal.getType()) &&
usersPos->useRepresentative())
- value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
+ value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
else
value = parentVal;
// The second operation retrieves the users.
- value = builder.create<pdl_interp::GetUsersOp>(loc, value);
+ value = pdl_interp::GetUsersOp::create(builder, loc, value);
break;
}
case Predicates::ForEachPos: {
assert(!failureBlockStack.empty() && "expected valid failure block");
- auto foreach = builder.create<pdl_interp::ForEachOp>(
- loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
+ auto foreach = pdl_interp::ForEachOp::create(
+ builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
value = foreach.getLoopVariable();
// Create the continuation block.
Block *continueBlock = builder.createBlock(&foreach.getRegion());
- builder.create<pdl_interp::ContinueOp>(loc);
+ pdl_interp::ContinueOp::create(builder, loc);
failureBlockStack.push_back(continueBlock);
currentBlock = &foreach.getRegion().front();
@@ -311,62 +311,64 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
}
case Predicates::OperandPos: {
auto *operandPos = cast<OperandPosition>(pos);
- value = builder.create<pdl_interp::GetOperandOp>(
- loc, builder.getType<pdl::ValueType>(), parentVal,
+ value = pdl_interp::GetOperandOp::create(
+ builder, loc, builder.getType<pdl::ValueType>(), parentVal,
operandPos->getOperandNumber());
break;
}
case Predicates::OperandGroupPos: {
auto *operandPos = cast<OperandGroupPosition>(pos);
Type valueTy = builder.getType<pdl::ValueType>();
- value = builder.create<pdl_interp::GetOperandsOp>(
- loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ value = pdl_interp::GetOperandsOp::create(
+ builder, loc,
+ operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
parentVal, operandPos->getOperandGroupNumber());
break;
}
case Predicates::AttributePos: {
auto *attrPos = cast<AttributePosition>(pos);
- value = builder.create<pdl_interp::GetAttributeOp>(
- loc, builder.getType<pdl::AttributeType>(), parentVal,
+ value = pdl_interp::GetAttributeOp::create(
+ builder, loc, builder.getType<pdl::AttributeType>(), parentVal,
attrPos->getName().strref());
break;
}
case Predicates::TypePos: {
if (isa<pdl::AttributeType>(parentVal.getType()))
- value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
+ value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
else
- value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
+ value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
break;
}
case Predicates::ResultPos: {
auto *resPos = cast<ResultPosition>(pos);
- value = builder.create<pdl_interp::GetResultOp>(
- loc, builder.getType<pdl::ValueType>(), parentVal,
+ value = pdl_interp::GetResultOp::create(
+ builder, loc, builder.getType<pdl::ValueType>(), parentVal,
resPos->getResultNumber());
break;
}
case Predicates::ResultGroupPos: {
auto *resPos = cast<ResultGroupPosition>(pos);
Type valueTy = builder.getType<pdl::ValueType>();
- value = builder.create<pdl_interp::GetResultsOp>(
- loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ value = pdl_interp::GetResultsOp::create(
+ builder, loc,
+ resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
parentVal, resPos->getResultGroupNumber());
break;
}
case Predicates::AttributeLiteralPos: {
auto *attrPos = cast<AttributeLiteralPosition>(pos);
- value =
- builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
+ value = pdl_interp::CreateAttributeOp::create(builder, loc,
+ attrPos->getValue());
break;
}
case Predicates::TypeLiteralPos: {
auto *typePos = cast<TypeLiteralPosition>(pos);
Attribute rawTypeAttr = typePos->getValue();
if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
- value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
+ value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
else
- value = builder.create<pdl_interp::CreateTypesOp>(
- loc, cast<ArrayAttr>(rawTypeAttr));
+ value = pdl_interp::CreateTypesOp::create(builder, loc,
+ cast<ArrayAttr>(rawTypeAttr));
break;
}
case Predicates::ConstraintResultPos: {
@@ -413,56 +415,59 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
Predicates::Kind kind = question->getKind();
switch (kind) {
case Predicates::IsNotNullQuestion:
- builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
+ pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure);
break;
case Predicates::OperationNameQuestion: {
auto *opNameAnswer = cast<OperationNameAnswer>(answer);
- builder.create<pdl_interp::CheckOperationNameOp>(
- loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
+ pdl_interp::CheckOperationNameOp::create(
+ builder, loc, val, opNameAnswer->getValue().getStringRef(), success,
+ failure);
break;
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
- builder.create<pdl_interp::CheckTypesOp>(
- loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
+ pdl_interp::CheckTypesOp::create(builder, loc, val,
+ llvm::cast<ArrayAttr>(ans->getValue()),
+ success, failure);
else
- builder.create<pdl_interp::CheckTypeOp>(
- loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
+ pdl_interp::CheckTypeOp::create(builder, loc, val,
+ llvm::cast<TypeAttr>(ans->getValue()),
+ success, failure);
break;
}
case Predicates::AttributeQuestion: {
auto *ans = cast<AttributeAnswer>(answer);
- builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
- success, failure);
+ pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
+ success, failure);
break;
}
case Predicates::OperandCountAtLeastQuestion:
case Predicates::OperandCountQuestion:
- builder.create<pdl_interp::CheckOperandCountOp>(
- loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ pdl_interp::CheckOperandCountOp::create(
+ builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
/*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
success, failure);
break;
case Predicates::ResultCountAtLeastQuestion:
case Predicates::ResultCountQuestion:
- builder.create<pdl_interp::CheckResultCountOp>(
- loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ pdl_interp::CheckResultCountOp::create(
+ builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
/*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
success, failure);
break;
case Predicates::EqualToQuestion: {
bool trueAnswer = isa<TrueAnswer>(answer);
- builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
- trueAnswer ? success : failure,
- trueAnswer ? failure : success);
+ pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
+ trueAnswer ? success : failure,
+ trueAnswer ? failure : success);
break;
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
- auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
- loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
- cstQuestion->getIsNegated(), success, failure);
+ auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
+ builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
+ args, cstQuestion->getIsNegated(), success, failure);
constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
@@ -487,7 +492,7 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
blocks.push_back(it.second);
values.push_back(cast<PredT>(it.first)->getValue());
}
- builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
+ OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks);
}
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
@@ -536,12 +541,14 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
switch (kind) {
case Predicates::OperandCountAtLeastQuestion:
- builder.create<pdl_interp::CheckOperandCountOp>(
- loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
+ pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
+ /*compareAtLeast=*/true,
+ childBlock, defaultDest);
break;
case Predicates::ResultCountAtLeastQuestion:
- builder.create<pdl_interp::CheckResultCountOp>(
- loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
+ pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
+ /*compareAtLeast=*/true,
+ childBlock, defaultDest);
break;
default:
llvm_unreachable("Generating invalid AtLeast operation");
@@ -619,8 +626,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
- auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
- pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
+ auto matchOp = pdl_interp::RecordMatchOp::create(
+ builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back());
@@ -632,8 +639,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
builder.setInsertionPointToEnd(rewriterModule.getBody());
- auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
- pattern.getLoc(), "pdl_generated_rewriter",
+ auto rewriterFunc = pdl_interp::FuncOp::create(
+ builder, pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType({}, {}));
rewriterSymbolTable.insert(rewriterFunc);
@@ -651,18 +658,18 @@ SymbolRefAttr PatternLowering::generateRewriter(
Operation *oldOp = oldValue.getDefiningOp();
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
if (Attribute value = attrOp.getValueAttr()) {
- return newValue = builder.create<pdl_interp::CreateAttributeOp>(
- attrOp.getLoc(), value);
+ return newValue = pdl_interp::CreateAttributeOp::create(
+ builder, attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
- return newValue = builder.create<pdl_interp::CreateTypeOp>(
- typeOp.getLoc(), type);
+ return newValue = pdl_interp::CreateTypeOp::create(
+ builder, typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
- return newValue = builder.create<pdl_interp::CreateTypesOp>(
- typeOp.getLoc(), typeOp.getType(), type);
+ return newValue = pdl_interp::CreateTypesOp::create(
+ builder, typeOp.getLoc(), typeOp.getType(), type);
}
}
@@ -684,8 +691,9 @@ SymbolRefAttr PatternLowering::generateRewriter(
auto mappedArgs =
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
- builder.create<pdl_interp::ApplyRewriteOp>(
- rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
+ pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
+ /*resultTypes=*/TypeRange(), rewriteName,
+ args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
for (Operation &rewriteOp : *rewriter.getBody()) {
@@ -703,7 +711,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
/*results=*/{}));
- builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
+ pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
return SymbolRefAttr::get(
builder.getContext(),
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
@@ -716,9 +724,9 @@ void PatternLowering::generateRewriter(
SmallVector<Value, 2> arguments;
for (Value argument : rewriteOp.getArgs())
arguments.push_back(mapRewriteValue(argument));
- auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
- rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
- arguments);
+ auto interpOp = pdl_interp::ApplyRewriteOp::create(
+ builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
+ rewriteOp.getNameAttr(), arguments);
for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
rewriteValues[std::get<0>(it)] = std::get<1>(it);
}
@@ -726,16 +734,16 @@ void PatternLowering::generateRewriter(
void PatternLowering::generateRewriter(
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
- attrOp.getLoc(), attrOp.getValueAttr());
+ Value newAttr = pdl_interp::CreateAttributeOp::create(
+ builder, attrOp.getLoc(), attrOp.getValueAttr());
rewriteValues[attrOp] = newAttr;
}
void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
- mapRewriteValue(eraseOp.getOpValue()));
+ pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
+ mapRewriteValue(eraseOp.getOpValue()));
}
void PatternLowering::generateRewriter(
@@ -756,9 +764,9 @@ void PatternLowering::generateRewriter(
// Create the new operation.
Location loc = operationOp.getLoc();
- Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
- loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
- attributes, operationOp.getAttributeValueNames());
+ Value createdOp = pdl_interp::CreateOperationOp::create(
+ builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
+ operands, attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.getOp()] = createdOp;
// Generate accesses for any results that have their types constrained.
@@ -768,8 +776,8 @@ void PatternLowering::generateRewriter(
if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
- auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
- type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
+ auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
+ type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
}
return;
}
@@ -789,12 +797,13 @@ void PatternLowering::generateRewriter(
// groups because the exact index of the result is not statically known.
Value resultVal;
if (seenVariableLength)
- resultVal = builder.create<pdl_interp::GetResultsOp>(
- loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
+ resultVal = pdl_interp::GetResultsOp::create(
+ builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
+ it.index());
else
- resultVal = builder.create<pdl_interp::GetResultOp>(
- loc, valueTy, createdOp, it.index());
- type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
+ resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
+ createdOp, it.index());
+ type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
}
}
@@ -804,8 +813,8 @@ void PatternLowering::generateRewriter(
SmallVector<Value, 4> replOperands;
for (Value operand : rangeOp.getArguments())
replOperands.push_back(mapRewriteValue(operand));
- rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
- rangeOp.getLoc(), rangeOp.getType(), replOperands);
+ rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
+ builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
}
void PatternLowering::generateRewriter(
@@ -820,8 +829,8 @@ void PatternLowering::generateRewriter(
// Don't use replace if we know the replaced operation has no results.
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.getTypeValues().empty()) {
- replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
- replOp.getLoc(), mapRewriteValue(replOp)));
+ replOperands.push_back(pdl_interp::GetResultsOp::create(
+ builder, replOp.getLoc(), mapRewriteValue(replOp)));
}
} else {
for (Value operand : replaceOp.getReplValues())
@@ -830,29 +839,29 @@ void PatternLowering::generateRewriter(
// If there are no replacement values, just create an erase instead.
if (replOperands.empty()) {
- builder.create<pdl_interp::EraseOp>(
- replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
+ pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
+ mapRewriteValue(replaceOp.getOpValue()));
return;
}
- builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
- mapRewriteValue(replaceOp.getOpValue()),
- replOperands);
+ pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
+ mapRewriteValue(replaceOp.getOpValue()),
+ replOperands);
}
void PatternLowering::generateRewriter(
pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
- resultOp.getLoc(), builder.getType<pdl::ValueType>(),
+ rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
+ builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
void PatternLowering::generateRewriter(
pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
- resultOp.getLoc(), resultOp.getType(),
+ rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
+ builder, resultOp.getLoc(), resultOp.getType(),
mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
@@ -863,7 +872,7 @@ void PatternLowering::generateRewriter(
// type.
if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
rewriteValues[typeOp] =
- builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
+ pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
}
}
@@ -873,8 +882,8 @@ void PatternLowering::generateRewriter(
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
- rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
- typeOp.getLoc(), typeOp.getType(), typeAttr);
+ rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
+ builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
}
}
@@ -939,10 +948,10 @@ void PatternLowering::generateOperationResultTypeRewriter(
!replacedOp->isBeforeInBlock(op))
continue;
- Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
- replacedOp->getLoc(), mapRewriteValue(replOpVal));
- types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
- replacedOp->getLoc(), replacedOpResults));
+ Value replacedOpResults = pdl_interp::GetResultsOp::create(
+ builder, replacedOp->getLoc(), mapRewriteValue(replOpVal));
+ types.push_back(pdl_interp::GetValueTypeOp::create(
+ builder, replacedOp->getLoc(), replacedOpResults));
return;
}
@@ -985,16 +994,18 @@ void PDLToPDLInterpPass::runOnOperation() {
// Create the main matcher function This function contains all of the match
// related functionality from patterns in the module.
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
- auto matcherFunc = builder.create<pdl_interp::FuncOp>(
- module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
+ auto matcherFunc = pdl_interp::FuncOp::create(
+ builder, module.getLoc(),
+ pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
/*results=*/{}),
/*attrs=*/ArrayRef<NamedAttribute>());
// Create a nested module to hold the functions invoked for rewriting the IR
// after a successful match.
- ModuleOp rewriterModule = builder.create<ModuleOp>(
- module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
+ ModuleOp rewriterModule =
+ ModuleOp::create(builder, module.getLoc(),
+ pdl_interp::PDLInterpDialect::getRewriterModuleName());
// Generate the code for the patterns within the module.
PatternLowering generator(matcherFunc, rewriterModule, configMap);
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 0df91a243d07a..240491a51d2b9 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -340,7 +340,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
Operation *terminator = lastBodyBlock->getTerminator();
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto step = forOp.getStep();
- auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
+ auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
if (!stepped)
return failure();
@@ -348,7 +348,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
loopCarried.push_back(stepped);
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
auto branchOp =
- rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
+ cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
// llvm.loop_annotation attribute.
@@ -375,16 +375,15 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
SmallVector<Value, 8> destOperands;
destOperands.push_back(lowerBound);
llvm::append_range(destOperands, forOp.getInitArgs());
- rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
+ cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
- auto comparison = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, iv, upperBound);
+ auto comparison = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
- rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
- ArrayRef<Value>(), endBlock,
- ArrayRef<Value>());
+ cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
+ ArrayRef<Value>(), endBlock, ArrayRef<Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
@@ -409,7 +408,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
SmallVector<Location>(ifOp.getNumResults(), loc));
- rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
+ cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
}
// Move blocks from the "then" region to the region containing 'scf.if',
@@ -419,7 +418,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
+ cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock);
@@ -433,15 +432,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
+ cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
- /*trueArgs=*/ArrayRef<Value>(), elseBlock,
- /*falseArgs=*/ArrayRef<Value>());
+ cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
+ /*trueArgs=*/ArrayRef<Value>(), elseBlock,
+ /*falseArgs=*/ArrayRef<Value>());
// Ok, we're done!
rewriter.replaceOp(ifOp, continueBlock->getArguments());
@@ -459,13 +458,14 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
auto &region = op.getRegion();
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<cf::BranchOp>(loc, &region.front());
+ cf::BranchOp::create(rewriter, loc, &region.front());
for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block);
- rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
+ cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
+ terminatorOperands);
rewriter.eraseOp(terminator);
}
}
@@ -503,7 +503,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
for (auto [iv, lower, upper, step] :
llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep())) {
- ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
+ ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
ivs.push_back(forOp.getInductionVar());
auto iterRange = forOp.getRegionIterArgs();
iterArgs.assign(iterRange.begin(), iterRange.end());
@@ -517,7 +517,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
// A loop is constructed with an empty "yield" terminator if there are
// no results.
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
- rewriter.create<scf::YieldOp>(loc, forOp.getResults());
+ scf::YieldOp::create(rewriter, loc, forOp.getResults());
}
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -549,7 +549,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
// has been already created in loop construction).
if (!yieldOperands.empty()) {
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
- rewriter.create<scf::YieldOp>(loc, yieldOperands);
+ scf::YieldOp::create(rewriter, loc, yieldOperands);
}
rewriter.replaceOp(parallelOp, loopResults);
@@ -575,7 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
+ cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
// Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last
@@ -625,14 +625,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
+ cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- rewriter.create<cf::CondBranchOp>(condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
+ before, condOp.getArgs(), continuation,
+ ValueRange());
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
@@ -695,12 +695,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
// Cast switch index to integer case value.
- Value caseValue = rewriter.create<arith::IndexCastOp>(
- op.getLoc(), rewriter.getI32Type(), op.getArg());
+ Value caseValue = arith::IndexCastOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg());
- rewriter.create<cf::SwitchOp>(
- op.getLoc(), caseValue, *defaultBlock, ValueRange(),
- rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+ cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
+ ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues),
+ caseSuccessors, caseOperands);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index dcb48529a74e6..84cbd869c78ef 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -91,7 +91,7 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
- rewriter.create<emitc::VariableOp>(loc, varType, noInit);
+ emitc::VariableOp::create(rewriter, loc, varType, noInit);
resultVariables.push_back(var);
}
@@ -103,14 +103,14 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
- rewriter.create<emitc::AssignOp>(loc, var, value);
+ emitc::AssignOp::create(rewriter, loc, var, value);
}
SmallVector<Value> loadValues(const SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
- return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
+ return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
});
}
@@ -129,7 +129,7 @@ static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
assignValues(yieldOperands, resultVariables, rewriter, loc);
- rewriter.create<emitc::YieldOp>(loc);
+ emitc::YieldOp::create(rewriter, loc);
rewriter.eraseOp(yield);
return success();
@@ -164,8 +164,9 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
- emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
- loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
+ emitc::ForOp loweredFor =
+ emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
+ adaptor.getUpperBound(), adaptor.getStep());
Block *loweredBody = loweredFor.getBody();
@@ -257,7 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
bool hasElseBlock = !elseRegion.empty();
auto loweredIf =
- rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
+ emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
Region &loweredThenRegion = loweredIf.getThenRegion();
auto result = lowerRegion(thenRegion, loweredThenRegion);
@@ -304,8 +305,9 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
"create variables for results failed");
}
- auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
- loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
+ auto loweredSwitch =
+ emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
+ adaptor.getCases(), indexSwitchOp.getNumCases());
// Lowering all case regions.
for (auto pair :
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 844e66e927c4d..f191f3502cf5a 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -84,8 +84,8 @@ static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) {
// Get a Value that corresponds to the loop step. If the step is an attribute,
// materialize a corresponding constant using builder.
static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
- return builder.create<arith::ConstantIndexOp>(forOp.getLoc(),
- forOp.getStepAsInt());
+ return arith::ConstantIndexOp::create(builder, forOp.getLoc(),
+ forOp.getStepAsInt());
}
// Get a Value for the loop lower bound. If the value requires computation,
@@ -190,12 +190,12 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) {
return std::nullopt;
}
- Value range = builder.create<arith::SubIOp>(currentLoop.getLoc(),
- upperBound, lowerBound);
+ Value range = arith::SubIOp::create(builder, currentLoop.getLoc(),
+ upperBound, lowerBound);
Value step = getOrCreateStep(currentLoop, builder);
if (getConstantIntValue(step) != static_cast<int64_t>(1))
- range =
- builder.create<arith::CeilDivSIOp>(currentLoop.getLoc(), range, step);
+ range = arith::CeilDivSIOp::create(builder, currentLoop.getLoc(), range,
+ step);
dims.push_back(range);
lbs.push_back(lowerBound);
@@ -221,7 +221,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
// no loop mapped to a specific dimension, use constant "1" as its size.
Value constOne =
(numBlockDims < 3 || numThreadDims < 3)
- ? builder.create<arith::ConstantIndexOp>(rootForOp.getLoc(), 1)
+ ? arith::ConstantIndexOp::create(builder, rootForOp.getLoc(), 1)
: nullptr;
Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne;
Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
@@ -232,9 +232,9 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
// Create a launch op and move the body region of the innermost loop to the
// launch op.
- auto launchOp = builder.create<gpu::LaunchOp>(
- rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX,
- blockSizeY, blockSizeZ);
+ auto launchOp =
+ gpu::LaunchOp::create(builder, rootForOp.getLoc(), gridSizeX, gridSizeY,
+ gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
// Replace the loop terminator (loops contain only a single block) with the
// gpu terminator and move the operations from the loop body block to the gpu
@@ -244,7 +244,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
Location terminatorLoc = terminator.getLoc();
terminator.erase();
builder.setInsertionPointToEnd(innermostForOp.getBody());
- builder.create<gpu::TerminatorOp>(terminatorLoc, TypeRange());
+ gpu::TerminatorOp::create(builder, terminatorLoc, TypeRange());
launchOp.getBody().front().getOperations().splice(
launchOp.getBody().front().begin(),
innermostForOp.getBody()->getOperations());
@@ -263,10 +263,10 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
: getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
Value step = steps[en.index()];
if (getConstantIntValue(step) != static_cast<int64_t>(1))
- id = builder.create<arith::MulIOp>(rootForOp.getLoc(), step, id);
+ id = arith::MulIOp::create(builder, rootForOp.getLoc(), step, id);
Value ivReplacement =
- builder.create<arith::AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
+ arith::AddIOp::create(builder, rootForOp.getLoc(), *lbArgumentIt, id);
en.value().replaceAllUsesWith(ivReplacement);
std::advance(lbArgumentIt, 1);
std::advance(stepArgumentIt, 1);
@@ -319,8 +319,8 @@ static Value deriveStaticUpperBound(Value upperBound,
if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
for (const AffineExpr &result : minOp.getMap().getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) {
- return rewriter.create<arith::ConstantIndexOp>(minOp.getLoc(),
- constExpr.getValue());
+ return arith::ConstantIndexOp::create(rewriter, minOp.getLoc(),
+ constExpr.getValue());
}
}
}
@@ -344,8 +344,8 @@ static Value deriveStaticUpperBound(Value upperBound,
if ((lhs.value() < 0) != (rhs.value() < 0))
return {};
- return rewriter.create<arith::ConstantIndexOp>(
- multiplyOp.getLoc(), lhs.value() * rhs.value());
+ return arith::ConstantIndexOp::create(rewriter, multiplyOp.getLoc(),
+ lhs.value() * rhs.value());
}
}
@@ -422,8 +422,8 @@ static LogicalResult processParallelLoop(
if (launchIndependent(val))
return val;
if (auto constOp = val.getDefiningOp<arith::ConstantOp>())
- return rewriter.create<arith::ConstantOp>(constOp.getLoc(),
- constOp.getValue());
+ return arith::ConstantOp::create(rewriter, constOp.getLoc(),
+ constOp.getValue());
return {};
};
@@ -453,8 +453,8 @@ static LogicalResult processParallelLoop(
1, 2,
rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
rewriter.getAffineSymbolExpr(1));
- newIndex = rewriter.create<AffineApplyOp>(
- loc, annotation.getMap().compose(lowerAndStep),
+ newIndex = AffineApplyOp::create(
+ rewriter, loc, annotation.getMap().compose(lowerAndStep),
ValueRange{operand, ensureLaunchIndependent(step),
ensureLaunchIndependent(lowerBound)});
// If there was also a bound, insert that, too.
@@ -498,8 +498,8 @@ static LogicalResult processParallelLoop(
1, 2,
((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0))
.ceilDiv(rewriter.getAffineSymbolExpr(1))));
- Value launchBound = rewriter.create<AffineApplyOp>(
- loc, annotation.getBound().compose(stepMap),
+ Value launchBound = AffineApplyOp::create(
+ rewriter, loc, annotation.getBound().compose(stepMap),
ValueRange{
ensureLaunchIndependent(
cloningMap.lookupOrDefault(upperBound)),
@@ -517,10 +517,10 @@ static LogicalResult processParallelLoop(
if (!boundIsPrecise) {
// We are using an approximation, create a surrounding conditional.
Value originalBound = std::get<3>(config);
- arith::CmpIOp pred = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, newIndex,
+ arith::CmpIOp pred = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, newIndex,
cloningMap.lookupOrDefault(originalBound));
- scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, pred, false);
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, pred, false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Put a sentinel into the worklist so we know when to pop out of the
// if body again. We use the launchOp here, as that cannot be part of
@@ -530,10 +530,10 @@ static LogicalResult processParallelLoop(
}
} else {
// Create a sequential for loop.
- auto loopOp = rewriter.create<scf::ForOp>(
- loc, cloningMap.lookupOrDefault(lowerBound),
- cloningMap.lookupOrDefault(upperBound),
- cloningMap.lookupOrDefault(step));
+ auto loopOp = scf::ForOp::create(rewriter, loc,
+ cloningMap.lookupOrDefault(lowerBound),
+ cloningMap.lookupOrDefault(upperBound),
+ cloningMap.lookupOrDefault(step));
newIndex = loopOp.getInductionVar();
rewriter.setInsertionPointToStart(loopOp.getBody());
// Put a sentinel into the worklist so we know when to pop out of the loop
@@ -608,12 +608,12 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
// sizes. Those will be refined later as we discover them from mappings.
Location loc = parallelOp.getLoc();
Value constantOne =
- rewriter.create<arith::ConstantIndexOp>(parallelOp.getLoc(), 1);
- gpu::LaunchOp launchOp = rewriter.create<gpu::LaunchOp>(
- parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne,
- constantOne, constantOne);
+ arith::ConstantIndexOp::create(rewriter, parallelOp.getLoc(), 1);
+ gpu::LaunchOp launchOp = gpu::LaunchOp::create(
+ rewriter, parallelOp.getLoc(), constantOne, constantOne, constantOne,
+ constantOne, constantOne, constantOne);
rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
- rewriter.create<gpu::TerminatorOp>(loc);
+ gpu::TerminatorOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&launchOp.getBody().front());
IRMapping cloningMap;
@@ -667,7 +667,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
if (externalValues.size())
return failure();
// Replace by gpu.all_reduce.
- auto gpuRedOp = rewriter.create<gpu::AllReduceOp>(loc, newValue);
+ auto gpuRedOp = gpu::AllReduceOp::create(rewriter, loc, newValue);
cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult());
// Copy region.
rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(),
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 584ac2f11b670..34f372af1e4b5 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -187,8 +187,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperands()[reductionIndex].getType();
- auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
- "__scf_reduction", type);
+ auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
+ "__scf_reduction", type);
symbolTable.insert(decl);
builder.createBlock(&decl.getInitializerRegion(),
@@ -196,8 +196,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
{reduce.getOperands()[reductionIndex].getLoc()});
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
Value init =
- builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
- builder.create<omp::YieldOp>(reduce.getLoc(), init);
+ LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue);
+ omp::YieldOp::create(builder, reduce.getLoc(), init);
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
@@ -227,12 +227,12 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
{reduceOperandLoc, reduceOperandLoc});
Block *atomicBlock = &decl.getAtomicReductionRegion().back();
builder.setInsertionPointToEnd(atomicBlock);
- Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
- atomicBlock->getArgument(1));
- builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
- atomicBlock->getArgument(0), loaded,
- LLVM::AtomicOrdering::monotonic);
- builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
+ Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(),
+ atomicBlock->getArgument(1));
+ LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind,
+ atomicBlock->getArgument(0), loaded,
+ LLVM::AtomicOrdering::monotonic);
+ omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
return decl;
}
@@ -380,8 +380,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Allocate reduction variables. Make sure the we don't overflow the stack
// with local `alloca`s by saving and restoring the stack pointer.
Location loc = parallelOp.getLoc();
- Value one = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
+ Value one =
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
+ rewriter.getI64IntegerAttr(1));
SmallVector<Value> reductionVariables;
reductionVariables.reserve(parallelOp.getNumReductions());
auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
@@ -390,9 +391,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
"cannot create a reduction variable if the type is not an LLVM "
"pointer element");
- Value storage =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
- rewriter.create<LLVM::StoreOp>(loc, init, storage);
+ Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
+ init.getType(), one, 0);
+ LLVM::StoreOp::create(rewriter, loc, init, storage);
reductionVariables.push_back(storage);
}
@@ -411,8 +412,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
assert(redRegion.hasOneBlock() &&
"expect reduction region to have one block");
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
- Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
- rD.getType(), pvtRedVar);
+ Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
+ rD.getType(), pvtRedVar);
// Make a copy of the reduction combiner region in the body
mlir::OpBuilder builder(rewriter.getContext());
builder.setInsertionPoint(reduce);
@@ -427,7 +428,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
assert(yieldOp && yieldOp.getResults().size() == 1 &&
"expect YieldOp in reduction region to return one result");
Value redVal = yieldOp.getResults()[0];
- rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
+ LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
rewriter.eraseOp(yieldOp);
break;
}
@@ -437,12 +438,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
Value numThreadsVar;
if (numThreads > 0) {
- numThreadsVar = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsVar = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
}
// Create the parallel wrapper.
- auto ompParallel = rewriter.create<omp::ParallelOp>(
- loc,
+ auto ompParallel = omp::ParallelOp::create(
+ rewriter, loc,
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
@@ -464,7 +465,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
// Create worksharing loop wrapper.
- auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+ auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
if (!reductionVariables.empty()) {
wsloopOp.setReductionSymsAttr(
ArrayAttr::get(rewriter.getContext(), reductionSyms));
@@ -476,7 +477,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
wsloopOp.setReductionByref(
DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
}
- rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+ omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
// The wrapper's entry block arguments will define the reduction
// variables.
@@ -490,8 +491,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
parallelOp.getLoc()));
// Create loop nest and populate region with contents of scf.parallel.
- auto loopOp = rewriter.create<omp::LoopNestOp>(
- parallelOp.getLoc(), parallelOp.getLowerBound(),
+ auto loopOp = omp::LoopNestOp::create(
+ rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep());
rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
@@ -511,13 +512,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
rewriter.setInsertionPointToStart(&loopOpEntryBlock);
- auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
- TypeRange());
- rewriter.create<omp::YieldOp>(loc, ValueRange());
+ auto scope = memref::AllocaScopeOp::create(
+ rewriter, parallelOp.getLoc(), TypeRange());
+ omp::YieldOp::create(rewriter, loc, ValueRange());
Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
rewriter.mergeBlocks(ops, scopeBlock);
rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
- rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
+ memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange());
}
}
@@ -526,7 +527,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
results.reserve(reductionVariables.size());
for (auto [variable, type] :
llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
- Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
+ Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
results.push_back(res);
}
rewriter.replaceOp(parallelOp, results);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 78d13278fef53..dc92367fc58cd 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
rewriter.setInsertionPoint(newOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- loc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
allocas.push_back(alloc);
rewriter.setInsertionPointAfter(newOp);
- Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+ Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
@@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// a single back edge from the continue to header block, and a single exit
// from header to merge.
auto loc = forOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
OpBuilder::InsertionGuard guard(rewriter);
@@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
- rewriter.create<spirv::BranchOp>(loc, header, args);
+ spirv::BranchOp::create(rewriter, loc, header, args);
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = rewriter.create<spirv::SLessThanOp>(
- loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
+ auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
- rewriter.create<spirv::BranchConditionalOp>(
- loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
+ ArrayRef<Value>(), mergeBlock,
+ ArrayRef<Value>());
// Generate instructions to increment the step of the induction variable and
// branch to the header.
@@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
rewriter.setInsertionPointToEnd(continueBlock);
// Add the step to the induction variable and branch to the header.
- Value updatedIndVar = rewriter.create<spirv::IAddOp>(
- loc, newIndVar.getType(), newIndVar, adaptor.getStep());
- rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+ Value updatedIndVar = spirv::IAddOp::create(
+ rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
+ spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
// Infer the return types from the init operands. Vector type may get
// converted to CooperativeMatrix or to Vector type, to avoid having complex
@@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
// Create `spirv.selection` operation, selection header block and merge
// block.
- auto selectionOp =
- rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto selectionOp = spirv::SelectionOp::create(
+ rewriter, loc, spirv::SelectionControl::None);
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
selectionOp.getBody().end());
- rewriter.create<spirv::MergeOp>(loc);
+ spirv::MergeOp::create(rewriter, loc);
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
@@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &thenRegion = ifOp.getThenRegion();
auto *thenBlock = &thenRegion.front();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(thenRegion, mergeBlock);
auto *elseBlock = mergeBlock;
@@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &elseRegion = ifOp.getElseRegion();
elseBlock = &elseRegion.front();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(elseRegion, mergeBlock);
}
// Create a `spirv.BranchConditional` operation for selection header block.
rewriter.setInsertionPointToEnd(selectionHeaderBlock);
- rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
- thenBlock, ArrayRef<Value>(),
- elseBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
+ thenBlock, ArrayRef<Value>(), elseBlock,
+ ArrayRef<Value>());
replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
returnTypes);
@@ -310,7 +312,7 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
auto loc = terminatorOp.getLoc();
for (unsigned i = 0, e = operands.size(); i < e; i++)
- rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+ spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
if (isa<spirv::LoopOp>(parent)) {
// For loops we also need to update the branch jumping back to the
// header.
@@ -319,8 +321,8 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
SmallVector<Value, 8> args(br.getBlockArguments());
args.append(operands.begin(), operands.end());
rewriter.setInsertionPoint(br);
- rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
- args);
+ spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
+ args);
rewriter.eraseOp(br);
}
}
@@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = whileOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
Region &beforeRegion = whileOp.getBefore();
@@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Jump from the loop entry block to the loop header block.
rewriter.setInsertionPointToEnd(&entryBlock);
- rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
+ spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
auto condLoc = cond.getLoc();
@@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Create local variables before the scf.while op.
rewriter.setInsertionPoint(loopOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- condLoc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
// Load the final result values after the scf.while op.
rewriter.setInsertionPointAfter(loopOp);
- auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+ auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
resultValues[i] = loadResult;
// Store the current iteration's result value.
rewriter.setInsertionPointToEnd(&beforeBlock);
- rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+ spirv::StoreOp::create(rewriter, condLoc, alloc, res);
}
rewriter.setInsertionPointToEnd(&beforeBlock);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index d7ae9f0e94fe8..035f197b1eac2 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -68,7 +68,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
/// Copies the given number of bytes from src to dst pointers.
static void copy(Location loc, Value dst, Value src, Value size,
OpBuilder &builder) {
- builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false);
+ LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false);
}
/// Encodes the binding and descriptor set numbers into a new symbolic name.
@@ -194,8 +194,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
if (!kernelFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
- kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), newKernelFuncName,
+ kernelFunc = LLVM::LLVMFuncOp::create(
+ rewriter, rewriter.getUnknownLoc(), newKernelFuncName,
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
ArrayRef<Type>()));
rewriter.setInsertionPoint(launchOp);
@@ -245,8 +245,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
if (!dstGlobal) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
- dstGlobal = rewriter.create<LLVM::GlobalOp>(
- loc, dstGlobalType,
+ dstGlobal = LLVM::GlobalOp::create(
+ rewriter, loc, dstGlobalType,
/*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
/*alignment=*/0);
rewriter.setInsertionPoint(launchOp);
@@ -255,8 +255,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// Copy the data from src operand pointer to dst global variable. Save
// src, dst and size so that we can copy data back after emulating the
// kernel call.
- Value dst = rewriter.create<LLVM::AddressOfOp>(
- loc, typeConverter->convertType(spirvGlobal.getType()),
+ Value dst = LLVM::AddressOfOp::create(
+ rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
dstGlobal.getSymName());
copy(loc, dst, src, sizeBytes, rewriter);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1d92b5d5562b5..aae3271371c1f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -94,13 +94,13 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter) {
if (isa<VectorType>(srcType)) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<ShapedType>(srcType),
minusOneIntegerAttribute(srcType, rewriter)));
}
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
+ return LLVM::ConstantOp::create(rewriter, loc, dstType,
+ minusOneIntegerAttribute(srcType, rewriter));
}
/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
@@ -108,14 +108,14 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter, double value) {
if (auto vecType = dyn_cast<VectorType>(srcType)) {
auto floatType = cast<FloatType>(vecType.getElementType());
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(vecType,
rewriter.getFloatAttr(floatType, value)));
}
auto floatType = cast<FloatType>(srcType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType, rewriter.getFloatAttr(floatType, value));
+ return LLVM::ConstantOp::create(rewriter, loc, dstType,
+ rewriter.getFloatAttr(floatType, value));
}
/// Utility function for bitfield ops:
@@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
: getBitWidth(srcType);
if (valueBitWidth < targetBitWidth)
- return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
+ return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
// If the bit widths of `Count` and `Offset` are greater than the bit width
// of the target type, they are truncated. Truncation is safe since `Count`
// and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
// both values can be expressed in 8 bits.
if (valueBitWidth > targetBitWidth)
- return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
+ return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
return value;
}
@@ -151,12 +151,12 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
auto llvmVectorType = typeConverter.convertType(vectorType);
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
- Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
+ Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
for (unsigned i = 0; i < numElements; ++i) {
- auto index = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
- broadcasted = rewriter.create<LLVM::InsertElementOp>(
- loc, llvmVectorType, broadcasted, toBroadcast, index);
+ auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
+ rewriter.getI32IntegerAttr(i));
+ broadcasted = LLVM::InsertElementOp::create(
+ rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
}
return broadcasted;
}
@@ -217,8 +217,8 @@ static Type convertStructTypePacked(spirv::StructType type,
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
unsigned value) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, IntegerType::get(rewriter.getContext(), 32),
+ return LLVM::ConstantOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 32),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
@@ -322,8 +322,9 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
auto llvmIndexType = getTypeConverter()->convertType(indexType);
if (!llvmIndexType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
+ Value zero =
+ LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
+ rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
auto elementType = getTypeConverter()->convertType(
@@ -375,20 +376,20 @@ class BitFieldInsertPattern
// Create a mask with bits set outside [Offset, Offset + Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
Value maskShiftedByCount =
- rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
- Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
- maskShiftedByCount, minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
+ Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
+ maskShiftedByCount, minusOne);
Value maskShiftedByCountAndOffset =
- rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
- Value mask = rewriter.create<LLVM::XOrOp>(
- loc, dstType, maskShiftedByCountAndOffset, minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
+ Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
+ maskShiftedByCountAndOffset, minusOne);
// Extract unchanged bits from the `Base` that are outside of
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
Value baseAndMask =
- rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
+ LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
Value insertShiftedByOffset =
- rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
+ LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
insertShiftedByOffset);
return success();
@@ -470,23 +471,23 @@ class BitFieldSExtractPattern
auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
Value size =
isa<VectorType>(srcType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
- : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
+ : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
// Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
// at Offset + Count - 1 is the most significant bit now.
Value countPlusOffset =
- rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
+ LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
Value amountToShiftLeft =
- rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
- Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
- loc, dstType, op.getBase(), amountToShiftLeft);
+ LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
+ Value baseShiftedLeft = LLVM::ShlOp::create(
+ rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
// Shift the result right, filling the bits with the sign bit.
Value amountToShiftRight =
- rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
+ LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
amountToShiftRight);
return success();
@@ -516,13 +517,13 @@ class BitFieldUExtractPattern
// Create a mask with bits set at [0, Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
Value maskShiftedByCount =
- rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
- Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
- minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
+ Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
+ minusOne);
// Shift `Base` by `Offset` and apply the mask on it.
Value shiftedBase =
- rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
+ LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
return success();
}
@@ -694,8 +695,8 @@ class ExecutionModePattern
auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
// Create `llvm.mlir.global` with initializer region containing one block.
- auto global = rewriter.create<LLVM::GlobalOp>(
- UnknownLoc::get(context), structType, /*isConstant=*/true,
+ auto global = LLVM::GlobalOp::create(
+ rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true,
LLVM::Linkage::External, executionModeInfoName, Attribute(),
/*alignment=*/0);
Location loc = global.getLoc();
@@ -704,22 +705,23 @@ class ExecutionModePattern
// Initialize the struct and set the execution mode value.
rewriter.setInsertionPointToStart(block);
- Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
- Value executionMode = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type,
+ Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
+ Value executionMode = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
rewriter.getI32IntegerAttr(
static_cast<uint32_t>(executionModeAttr.getValue())));
- structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
- executionMode, 0);
+ SmallVector<int64_t> position{0};
+ structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
+ executionMode, position);
// Insert extra operands if they exist into execution mode info struct.
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto attr = values.getValue()[i];
- Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
- structValue = rewriter.create<LLVM::InsertValueOp>(
- loc, structValue, entry, ArrayRef<int64_t>({1, i}));
+ Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
+ structValue = LLVM::InsertValueOp::create(
+ rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
}
- rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
+ LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
rewriter.eraseOp(op);
return success();
}
@@ -913,7 +915,7 @@ class InverseSqrtPattern
Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
- Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
+ Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
return success();
}
@@ -973,10 +975,10 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
auto mask =
isa<VectorType>(srcType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
- : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+ : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
notOp.getOperand(), mask);
return success();
@@ -1034,8 +1036,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
return func;
OpBuilder b(symbolTable->getRegion(0));
- func = b.create<LLVM::LLVMFuncOp>(
- symbolTable->getLoc(), name,
+ func = LLVM::LLVMFuncOp::create(
+ b, symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setConvergent(convergent);
@@ -1047,7 +1049,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
ValueRange args) {
- auto call = builder.create<LLVM::CallOp>(loc, func, args);
+ auto call = LLVM::CallOp::create(builder, loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
@@ -1078,12 +1080,12 @@ class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
Location loc = controlBarrierOp->getLoc();
- Value execution = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
- Value memory = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
- Value semantics = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+ Value execution = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value memory = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+ Value semantics = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
auto call = createSPIRVBuiltinCall(loc, rewriter, func,
{execution, memory, semantics});
@@ -1255,10 +1257,12 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
Location loc = op.getLoc();
- Value scope = rewriter.create<LLVM::ConstantOp>(
- loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
- Value groupOp = rewriter.create<LLVM::ConstantOp>(
- loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+ Value scope = LLVM::ConstantOp::create(
+ rewriter, loc, i32Ty,
+ static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value groupOp = LLVM::ConstantOp::create(
+ rewriter, loc, i32Ty,
+ static_cast<int32_t>(adaptor.getGroupOperation()));
SmallVector<Value> operands{scope, groupOp};
operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
@@ -1368,7 +1372,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
return failure();
Block *headerBlock = loopOp.getHeaderBlock();
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
+ LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
rewriter.eraseBlock(entryBlock);
// Branch from merge block to end block.
@@ -1376,7 +1380,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
Operation *terminator = mergeBlock->getTerminator();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(mergeBlock);
- rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
+ LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
rewriter.replaceOp(loopOp, endBlock->getArguments());
@@ -1434,16 +1438,15 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
Operation *terminator = mergeBlock->getTerminator();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(mergeBlock);
- rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
+ LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
// Link current block to `true` and `false` blocks within the selection.
Block *trueBlock = condBrOp.getTrueBlock();
Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
- condBrOp.getTrueTargetOperands(),
- falseBlock,
- condBrOp.getFalseTargetOperands());
+ LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
+ condBrOp.getTrueTargetOperands(), falseBlock,
+ condBrOp.getFalseTargetOperands());
rewriter.eraseBlock(headerBlock);
rewriter.inlineRegionBefore(op.getBody(), continueBlock);
@@ -1521,8 +1524,8 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
Location loc = tanOp.getLoc();
- Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
- Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
+ Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
+ Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
return success();
}
@@ -1549,13 +1552,13 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
Value multiplied =
- rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
- Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
+ LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
+ Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Value numerator =
- rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+ LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
Value denominator =
- rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+ LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
denominator);
return success();
@@ -1594,8 +1597,8 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Value allocated =
- rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
- rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
+ LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
}
@@ -1656,7 +1659,7 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Create a new `LLVMFuncOp`
Location loc = funcOp.getLoc();
StringRef name = funcOp.getName();
- auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
+ auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
// Convert SPIR-V Function Control to equivalent LLVM function attribute
MLIRContext *context = funcOp.getContext();
@@ -1710,7 +1713,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
ConversionPatternRewriter &rewriter) const override {
auto newModuleOp =
- rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
+ ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
// Remove the terminator block that was automatically added by builder
@@ -1751,7 +1754,7 @@ class VectorShufflePattern
auto componentsArray = components.getValue();
auto *context = rewriter.getContext();
auto llvmI32Type = IntegerType::get(context, 32);
- Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
+ Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) {
if (!isa<IntegerAttr>(componentsArray[i]))
return op.emitError("unable to support non-constant component");
@@ -1767,16 +1770,17 @@ class VectorShufflePattern
baseVector = vector2;
}
- Value dstIndex = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
- Value index = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type,
+ Value dstIndex = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
+ rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+ Value index = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
- auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
- loc, scalarType, baseVector, index);
- targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
- extractOp, dstIndex);
+ auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
+ baseVector, index);
+ targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
+ extractOp, dstIndex);
}
rewriter.replaceOp(op, targetOp);
return success();
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index da9ad3dd67328..245e60b04ec31 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -32,7 +32,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override {
- rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
+ cf::AssertOp::create(rewriter, op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index bbe1490137bf8..7025c5a7daf93 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -82,40 +82,40 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
ValueRange rankDiffs, Value outputDimension) {
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Value broadcastedDim = one;
for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
Value shape = std::get<0>(tup);
Value rankDiff = std::get<1>(tup);
- Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
- outputDimension, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
+ outputDimension, rankDiff);
Type indexTy = lb.getIndexType();
broadcastedDim =
- lb.create<IfOp>(
- outOfBounds,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, broadcastedDim);
- },
- [&](OpBuilder &b, Location loc) {
- // The broadcasting logic is:
- // - if one extent (here we arbitrarily choose the
- // extent from the greater-rank operand) is equal to 1,
- // then take the extent from the other operand
- // - otherwise, take the extent as-is.
- // Note that this logic remains correct in the presence
- // of dimensions of zero extent.
- Value lesserRankOperandDimension = b.create<arith::SubIOp>(
- loc, indexTy, outputDimension, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{lesserRankOperandDimension});
-
- Value dimIsOne =
- b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- lesserRankOperandExtent, one);
- Value dim = b.create<arith::SelectOp>(
- loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
- b.create<scf::YieldOp>(loc, dim);
- })
+ IfOp::create(
+ lb, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ scf::YieldOp::create(b, loc, broadcastedDim);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // The broadcasting logic is:
+ // - if one extent (here we arbitrarily choose the
+ // extent from the greater-rank operand) is equal to 1,
+ // then take the extent from the other operand
+ // - otherwise, take the extent as-is.
+ // Note that this logic remains correct in the presence
+ // of dimensions of zero extent.
+ Value lesserRankOperandDimension = arith::SubIOp::create(
+ b, loc, indexTy, outputDimension, rankDiff);
+ Value lesserRankOperandExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{lesserRankOperandDimension});
+
+ Value dimIsOne =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ lesserRankOperandExtent, one);
+ Value dim = arith::SelectOp::create(
+ b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
+ scf::YieldOp::create(b, loc, dim);
+ })
.getResult(0);
}
return broadcastedDim;
@@ -133,7 +133,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -141,31 +141,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
- Value replacement = lb.create<tensor::GenerateOp>(
- getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ Value replacement = tensor::GenerateOp::create(
+ lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value broadcastedDim =
getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
rankDiffs, args[0]);
- b.create<tensor::YieldOp>(loc, broadcastedDim);
+ tensor::YieldOp::create(b, loc, broadcastedDim);
});
if (replacement.getType() != op.getType())
- replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+ replacement = tensor::CastOp::create(lb, op.getType(), replacement);
rewriter.replaceOp(op, replacement);
return success();
}
@@ -193,13 +193,13 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
auto loc = op.getLoc();
SmallVector<Value, 4> extentOperands;
for (auto extent : op.getShape()) {
- extentOperands.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
+ extentOperands.push_back(arith::ConstantIndexOp::create(
+ rewriter, loc, extent.getLimitedValue()));
}
Type resultTy =
RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
Value tensor =
- rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
+ tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -245,8 +245,8 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -254,26 +254,26 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
Type i1Ty = rewriter.getI1Type();
- Value trueVal =
- rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+ Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
+ rewriter.getBoolAttr(true));
- auto reduceResult = lb.create<ForOp>(
- loc, zero, maxRank, one, ValueRange{trueVal},
+ auto reduceResult = ForOp::create(
+ lb, loc, zero, maxRank, one, ValueRange{trueVal},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
// Find a non-1 dim, if it exists. Note that the first part of this
// could reuse the Broadcast lowering entirely, but we redo the work
@@ -285,38 +285,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
Value shape, rankDiff;
std::tie(shape, rankDiff) = tup;
- Value outOfBounds = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, iv, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
broadcastable =
- b.create<IfOp>(
- loc, outOfBounds,
- [&](OpBuilder &b, Location loc) {
- // Non existent dimensions are always broadcastable
- b.create<scf::YieldOp>(loc, broadcastable);
- },
- [&](OpBuilder &b, Location loc) {
- // Every value needs to be either 1, or the same non-1
- // value to be broadcastable in this dim.
- Value operandDimension =
- b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
- Value dimensionExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{operandDimension});
-
- Value equalOne = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent, one);
- Value equalBroadcasted = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent,
- broadcastedDim);
- Value result = b.create<arith::AndIOp>(
- loc, broadcastable,
- b.create<arith::OrIOp>(loc, equalOne,
- equalBroadcasted));
- b.create<scf::YieldOp>(loc, result);
- })
+ IfOp::create(
+ b, loc, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ // Non existent dimensions are always broadcastable
+ scf::YieldOp::create(b, loc, broadcastable);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Every value needs to be either 1, or the same non-1
+ // value to be broadcastable in this dim.
+ Value operandDimension =
+ arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
+ Value dimensionExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{operandDimension});
+
+ Value equalOne = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
+ Value equalBroadcasted =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ dimensionExtent, broadcastedDim);
+ Value result = arith::AndIOp::create(
+ b, loc, broadcastable,
+ arith::OrIOp::create(b, loc, equalOne,
+ equalBroadcasted));
+ scf::YieldOp::create(b, loc, result);
+ })
.getResult(0);
}
- b.create<scf::YieldOp>(loc, broadcastable);
+ scf::YieldOp::create(b, loc, broadcastable);
});
rewriter.replaceOp(op, reduceResult.getResults().front());
@@ -339,7 +339,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
// Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
// lowerings. This can be further optimized if needed to avoid intermediate
// steps.
- auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
+ auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
op.getIndex());
return success();
@@ -421,16 +421,17 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
Type indexTy = rewriter.getIndexType();
Value rank =
- rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);
- auto loop = rewriter.create<scf::ForOp>(
- loc, zero, rank, one, op.getInitVals(),
+ auto loop = scf::ForOp::create(
+ rewriter, loc, zero, rank, one, op.getInitVals(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
+ Value extent =
+ tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);
SmallVector<Value, 2> mappedValues{iv, extent};
mappedValues.append(args.begin(), args.end());
@@ -444,7 +445,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
SmallVector<Value, 2> mappedResults;
for (auto result : reduceBody->getTerminator()->getOperands())
mappedResults.push_back(mapping.lookup(result));
- b.create<scf::YieldOp>(loc, mappedResults);
+ scf::YieldOp::create(b, loc, mappedResults);
});
rewriter.replaceOp(op, loop.getResults());
@@ -507,44 +508,44 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
Type indexTy = rewriter.getIndexType();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value firstShape = adaptor.getShapes().front();
Value firstRank =
- rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
Value result = nullptr;
// Generate a linear sequence of compares, all with firstShape as lhs.
for (Value shape : adaptor.getShapes().drop_front(1)) {
- Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
- Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- firstRank, rank);
- auto same = rewriter.create<IfOp>(
- loc, eqRank,
+ Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
+ Value eqRank = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
+ auto same = IfOp::create(
+ rewriter, loc, eqRank,
[&](OpBuilder &b, Location loc) {
- Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ Value one = arith::ConstantIndexOp::create(b, loc, 1);
Value init =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
- auto loop = b.create<scf::ForOp>(
- loc, zero, firstRank, one, ValueRange{init},
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
+ auto loop = scf::ForOp::create(
+ b, loc, zero, firstRank, one, ValueRange{init},
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value conj = args[0];
Value lhsExtent =
- b.create<tensor::ExtractOp>(loc, firstShape, iv);
- Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
- Value eqExtent = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
- Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
- b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+ tensor::ExtractOp::create(b, loc, firstShape, iv);
+ Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
+ Value eqExtent = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
+ Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
+ scf::YieldOp::create(b, loc, ValueRange({conjNext}));
});
- b.create<scf::YieldOp>(loc, loop.getResults());
+ scf::YieldOp::create(b, loc, loop.getResults());
},
[&](OpBuilder &b, Location loc) {
Value result =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
- b.create<scf::YieldOp>(loc, result);
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
+ scf::YieldOp::create(b, loc, result);
});
result = !result ? same.getResult(0)
- : rewriter.create<arith::AndIOp>(loc, result,
- same.getResult(0));
+ : arith::AndIOp::create(rewriter, loc, result,
+ same.getResult(0));
}
rewriter.replaceOp(op, result);
return success();
@@ -581,18 +582,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
- Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
+ Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
extentValues.push_back(extent);
} else {
- Value extent = rewriter.create<arith::ConstantIndexOp>(
- loc, rankedTensorTy.getDimSize(i));
+ Value extent = arith::ConstantIndexOp::create(
+ rewriter, loc, rankedTensorTy.getDimSize(i));
extentValues.push_back(extent);
}
}
// Materialize extent tensor.
- Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
+ Value staticExtentTensor = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
extentValues);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
@@ -601,13 +602,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Lower to `tensor.generate` otherwise.
auto *ctx = rewriter.getContext();
- Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
+ Value rank = tensor::RankOp::create(rewriter, loc, tensor);
rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
- Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
- b.create<tensor::YieldOp>(loc, extent);
+ Value extent = tensor::DimOp::create(b, loc, tensor, dim);
+ tensor::YieldOp::create(b, loc, extent);
});
return success();
@@ -634,22 +635,22 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value zero = b.create<arith::ConstantIndexOp>(0);
- Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
+ Value zero = arith::ConstantIndexOp::create(b, 0);
+ Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);
// index < 0 ? index + rank : index
Value originalIndex = adaptor.getIndex();
- Value add = b.create<arith::AddIOp>(originalIndex, rank);
+ Value add = arith::AddIOp::create(b, originalIndex, rank);
Value indexIsNegative =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
- Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
+ Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);
- Value one = b.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(b, 1);
Value head =
- b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
- Value tailSize = b.create<arith::SubIOp>(rank, index);
- Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
- tailSize, one);
+ tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
+ Value tailSize = arith::SubIOp::create(b, rank, index);
+ Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
+ tailSize, one);
rewriter.replaceOp(op, {head, tail});
return success();
}
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 2c4d27502a521..f24972f6b6ee1 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -68,10 +68,10 @@ class TensorExtractPattern final
// We could use the initializer directly; but certain driver compilers
// have bugs dealing with that. So for now, use spirv.Store for
// initialization.
- varOp = rewriter.create<spirv::VariableOp>(loc, varType,
- spirv::StorageClass::Function,
- /*initializer=*/nullptr);
- rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
+ varOp = spirv::VariableOp::create(rewriter, loc, varType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
+ spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor());
} else {
// Need to store the value to the local variable. It's questionable
// whether we want to support such case though.
@@ -83,7 +83,7 @@ class TensorExtractPattern final
Value index = spirv::linearizeIndex(adaptor.getIndices(), strides,
/*offset=*/0, indexType, loc, rewriter);
- auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
+ auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 40ad63610e23f..044b725c7d805 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
Value getConstantValue(Location loc, Type type, int64_t value,
PatternRewriter &rewriter) {
- return rewriter.create<arith::ConstantOp>(
- loc, getConstantAttr(type, value, rewriter));
+ return arith::ConstantOp::create(rewriter, loc,
+ getConstantAttr(type, value, rewriter));
}
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
@@ -82,41 +82,41 @@ class ApplyScaleGenericOpConverter
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = value;
if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
- value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+ value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
Value multiplier64 =
- rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+ arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
Value multiply64 =
- rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
+ arith::MulIOp::create(rewriter, loc, value64, multiplier64);
// Apply normal rounding.
- Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
- Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
- round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
- multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
+ Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
+ Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
+ round = arith::ShRUIOp::create(rewriter, loc, round, one64);
+ multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
- Value positive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value, zero);
+ Value positive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value, zero);
Value dir =
- rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
- Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
- Value valid = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
+ arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
+ Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
+ Value valid = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
multiply64 =
- rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
+ arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
}
- Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
- Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
+ Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
+ Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
rewriter.replaceOp(op, result32);
return success();
@@ -146,7 +146,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
Value value32 = op.getValue();
Value multiplier32 = op.getMultiplier();
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Constants used during the scaling operation.
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
@@ -158,86 +158,87 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
// Compute the multiplication in 64-bits then select the high / low parts.
// Grab out the high/low of the computation
auto value64 =
- rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
+ arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
Value low32 = value64.getLow();
Value high32 = value64.getHigh();
// Determine the direction and amount to shift the high bits.
- Value shiftOver32 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
- Value roundHighBits = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
+ Value shiftOver32 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
+ Value roundHighBits = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
Value shiftHighL =
- rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
+ arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
Value shiftHighR =
- rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
+ arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
shiftHighL =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
shiftHighR =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
- Value valuePositive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value32, zero32);
+ Value valuePositive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
- Value roundDir =
- rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
+ Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
+ one32, negOne32);
roundDir =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
- Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
- Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
- Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
+ Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
+ Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
+ Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
Value shiftRound =
- rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
+ arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
- low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
+ low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
}
// Conditionally apply rounding in the low bits.
{
- Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
- roundBit);
-
- Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
- Value wasRounded = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ugt, low32, newLow32);
+ Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
+ roundBit);
+
+ Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
+ Value wasRounded = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
low32 = newLow32;
- Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
+ Value rounded32 =
+ arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
}
// Conditionally apply rounding in the high bits.
{
Value shiftSubOne =
- rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
- zero32);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
+ arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
+ zero32);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
}
// Combine the correct high/low bits into the final rescale result.
- high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
- high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
- low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
- low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
+ high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
+ high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
+ low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
+ low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
// Apply the rounding behavior and shift to the final alignment.
- Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
+ Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
// Truncate if necessary.
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
- result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
+ result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2f608bbd637b4..ec55091cd7eb8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -70,14 +70,14 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
// Unordered comparison of NaN against itself will always return true.
- Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
- op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
- Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
- op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+ Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
+ arith::CmpFPredicate::UNO, lhs, lhs);
+ Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
+ arith::CmpFPredicate::UNO, rhs, rhs);
Value rhsOrResult =
- rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
- return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
- rhsOrResult);
+ arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
+ return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
+ rhsOrResult);
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -89,38 +89,38 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::AbsOp
if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
+ return math::AbsFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(elementTy));
- auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
- return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
+ auto zero = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getZeroAttr(elementTy));
+ auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
}
// tosa::AddOp
if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
+ return arith::AddFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
+ return arith::AddIOp::create(rewriter, loc, resultTypes, args);
// tosa::SubOp
if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
+ return arith::SubFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
+ return arith::SubIOp::create(rewriter, loc, resultTypes, args);
// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
+ return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
// tosa::ReciprocalOp
if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
auto one =
- rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
- return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
+ arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
+ return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
}
// tosa::MulOp
@@ -140,7 +140,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
"Cannot have shift value for float");
return nullptr;
}
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
+ return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
+ args[1]);
}
if (isa<IntegerType>(elementTy)) {
@@ -149,21 +150,21 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (shift > 0) {
auto shiftConst =
- rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+ arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
if (!a.getType().isInteger(32))
- a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+ a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
- b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+ b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = rewriter.create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), a, b, shiftConst,
+ auto result = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getStringAttr("SINGLE_ROUND"));
if (elementTy.isInteger(32))
return result;
- return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ return arith::TruncIOp::create(rewriter, loc, elementTy, result);
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -171,11 +172,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
if (aWidth < cWidth)
- a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+ a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
if (bWidth < cWidth)
- b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+ b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
- return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
}
}
@@ -201,14 +202,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int64_t outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+ return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
if (isa<IntegerType>(elementTy)) {
if (!inZp && !outZp) {
- auto constant = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(elementTy, 0));
- return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
- args[0]);
+ auto constant = arith::ConstantOp::create(
+ rewriter, loc, IntegerAttr::get(elementTy, 0));
+ return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
+ args[0]);
}
// Compute the maximum value that can occur in the intermediate buffer.
@@ -231,214 +232,214 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ Value zpAddValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
// The negation can be applied by doing:
// outputValue = inZp + outZp - inputValue
auto ext =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
- auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
+ auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
// Clamp to the negation range.
- Value min = rewriter.create<arith::ConstantIntOp>(
- loc, intermediateType,
+ Value min = arith::ConstantIntOp::create(
+ rewriter, loc, intermediateType,
APInt::getSignedMinValue(inputBitWidth).getSExtValue());
- Value max = rewriter.create<arith::ConstantIntOp>(
- loc, intermediateType,
+ Value max = arith::ConstantIntOp::create(
+ rewriter, loc, intermediateType,
APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
// Truncate to the final value.
- return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
}
}
// tosa::BitwiseAndOp
if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
+ return arith::AndIOp::create(rewriter, loc, resultTypes, args);
// tosa::BitwiseOrOp
if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
+ return arith::OrIOp::create(rewriter, loc, resultTypes, args);
// tosa::BitwiseNotOp
if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
auto allOnesAttr = rewriter.getIntegerAttr(
elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
- auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
+ auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
}
// tosa::BitwiseXOrOp
if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalLeftShiftOp
if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
+ return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalRightShiftOp
if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
+ return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
// tosa::ArithmeticRightShiftOp
if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
- auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
+ auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
if (!round) {
return result;
}
Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
- auto one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
- auto zero =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(elementTy, 1));
+ auto zero = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(elementTy, 0));
auto i1one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
+ arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
// Checking that input2 != 0
- auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[1], zero);
+ auto shiftValueGreaterThanZero = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
// Checking for the last bit of input1 to be 1
auto subtract =
- rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
+ arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
auto shifted =
- rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
+ arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
->getResults();
- auto truncated = rewriter.create<arith::TruncIOp>(
- loc, i1Ty, shifted, ArrayRef<NamedAttribute>());
+ auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
+ ArrayRef<NamedAttribute>());
auto isInputOdd =
- rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
+ arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
- auto shouldRound = rewriter.create<arith::AndIOp>(
- loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ auto shouldRound = arith::AndIOp::create(
+ rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
auto extended =
- rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
- return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
+ arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
+ return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
}
// tosa::ClzOp
if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
+ return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
}
// tosa::LogicalAnd
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
+ return arith::AndIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalNot
if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(elementTy, 1));
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(elementTy, 1));
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
}
// tosa::LogicalOr
if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
+ return arith::OrIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalXor
if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// tosa::PowOp
if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
+ return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
// tosa::RsqrtOp
if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
+ return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
// tosa::LogOp
if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
+ return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
// tosa::ExpOp
if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
+ return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
// tosa::SinOp
if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
+ return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
// tosa::CosOp
if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
+ return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
// tosa::TanhOp
if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
+ return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
// tosa::ErfOp
if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+ return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
// tosa::GreaterOp
if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
+ args[0], args[1]);
if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
+ args[0], args[1]);
// tosa::GreaterEqualOp
if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
+ args[0], args[1]);
if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
+ args[0], args[1]);
// tosa::EqualOp
if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
+ args[0], args[1]);
if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ args[0], args[1]);
// tosa::SelectOp
if (isa<tosa::SelectOp>(op)) {
elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
- return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
+ return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
}
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
- auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
rewriter, args[0], args[1], max);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
- return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
}
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
- auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
rewriter, args[0], args[1], min);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
- return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
+ return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
}
// tosa::CeilOp
if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::CeilOp>(loc, resultTypes, args);
+ return math::CeilOp::create(rewriter, loc, resultTypes, args);
// tosa::FloorOp
if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::FloorOp>(loc, resultTypes, args);
+ return math::FloorOp::create(rewriter, loc, resultTypes, args);
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
@@ -449,10 +450,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
APFloat::rmNearestTiesToEven, &losesInfo);
maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
- auto min = rewriter.create<arith::ConstantOp>(
- loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
- auto max = rewriter.create<arith::ConstantOp>(
- loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
+ auto min = arith::ConstantOp::create(
+ rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
+ auto max = arith::ConstantOp::create(
+ rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
auto clampOp = llvm::cast<tosa::ClampOp>(op);
@@ -478,11 +479,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// return init if x == NaN else result
// Unordered comparison of NaN against itself will always return true.
- Value isNaN = rewriter.create<arith::CmpFOp>(
- op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+ Value isNaN = arith::CmpFOp::create(
+ rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
// TOSA specifies that in "ignore" NaN mode the result is "min" if the input
// is NaN.
- return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
+ return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
}
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -515,10 +516,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
min = std::min(min, maxRepresentable);
max = std::min(max, maxRepresentable);
- auto minVal = rewriter.create<arith::ConstantIntOp>(
- loc, min, intTy.getIntOrFloatBitWidth());
- auto maxVal = rewriter.create<arith::ConstantIntOp>(
- loc, max, intTy.getIntOrFloatBitWidth());
+ auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
+ intTy.getIntOrFloatBitWidth());
+ auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
+ intTy.getIntOrFloatBitWidth());
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
intTy.isUnsignedInteger());
}
@@ -526,11 +527,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::SigmoidOp
if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
auto one =
- rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
- auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
- auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
- auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
- return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
+ arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
+ auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
+ auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
+ auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
+ return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
}
// tosa::CastOp
@@ -549,21 +550,21 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return args.front();
if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
- return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// 1-bit integers need to be treated as signless.
if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
- return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// Unsigned integers need an unrealized cast so that they can be passed
// to UIToFP.
@@ -574,25 +575,25 @@ static Value createLinalgBodyCalculationForElementwiseOp(
loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
args[0])
.getResult(0);
- return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
- unrealizedCast);
+ return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
+ unrealizedCast);
}
// All other si-to-fp conversions should be handled by SIToFP.
if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
- return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// Casting to boolean, floats need to only be checked as not-equal to zero.
if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(srcTy, 0.0));
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
- args.front(), zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(srcTy, 0.0));
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
+ args.front(), zero);
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+ auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
// Check whether neither int min nor int max can be represented in the
@@ -601,37 +602,42 @@ static Value createLinalgBodyCalculationForElementwiseOp(
APFloat::semanticsMaxExponent(fltSemantics)) {
// Use cmp + select to replace infinites by int min / int max. Other
// integral values can be represented in the integer space.
- auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
- auto posInf = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
- APFloat::getInf(fltSemantics)));
- auto negInf = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APFloat::getInf(fltSemantics, /*Negative=*/true)));
- auto overflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UEQ, rounded, posInf);
- auto underflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UEQ, rounded, negInf);
- auto intMin = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
- auto intMax = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
+ auto posInf = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics)));
+ auto negInf = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics, /*Negative=*/true)));
+ auto overflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+ auto underflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+ auto intMin = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMax = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
auto maxClamped =
- rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
- return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
- maxClamped);
+ arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
+ return arith::SelectOp::create(rewriter, loc, underflow, intMin,
+ maxClamped);
}
- auto intMinFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
+ auto intMinFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
// Check whether the mantissa has enough bits to represent int max.
if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
@@ -640,58 +646,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// consists of a single leading bit. Therefore we can clamp the input
// in the floating-point domain.
- auto intMaxFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
+ auto intMaxFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
Value clamped =
clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
}
// Due to earlier check we know exponant range is big enough to represent
// int min. We can therefore rely on int max + 1 being representable as
// well because it's just int min with a positive sign. So clamp the min
// value and compare against that to select the max int value if needed.
- auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- static_cast<double>(
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()) +
- 1.0f));
-
- auto intMax = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMaxPlusOneFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ static_cast<double>(
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()) +
+ 1.0f));
+
+ auto intMax = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
auto minClampedFP =
- rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
auto minClamped =
- rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
- auto overflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
- return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
- minClamped);
+ arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
+ auto overflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return arith::SelectOp::create(rewriter, loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
// zero.
if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
- Value zero = rewriter.create<arith::ConstantIntOp>(
- loc, 0, srcTy.getIntOrFloatBitWidth());
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
- args.front(), zero);
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
+ srcTy.getIntOrFloatBitWidth());
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
+ args.front(), zero);
}
if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
- return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
+ return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
}
}
@@ -710,14 +719,14 @@ static Value createIndex(PatternRewriter &rewriter, Location loc,
auto [it, inserted] = indexPool.try_emplace(index);
if (inserted)
it->second =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
return it->second;
}
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value tensor, int64_t index) {
auto indexValue = createIndex(rewriter, loc, indexPool, index);
- return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
+ return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
@@ -783,7 +792,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
auto nextSize =
getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
- targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
+ targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
}
return {targetSize, nullptr};
}
@@ -838,8 +847,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Check if broadcast is necessary
auto one = createIndex(rewriter, loc, indexPool, 1);
auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
- auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, runtimeSize, one);
+ auto broadcastNecessary = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
@@ -855,8 +864,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
operand, index);
outputTensorShape.push_back(size);
}
- Value outputTensor = opBuilder.create<tensor::EmptyOp>(
- loc, outputTensorShape, rankedTensorType.getElementType());
+ Value outputTensor = tensor::EmptyOp::create(
+ opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
// Emit 'linalg.generic' op
auto resultTensor =
@@ -866,7 +875,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
// Emit 'linalg.yield' op
- opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
+ linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
})
.getResult(0);
@@ -875,17 +884,17 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
loc, operand.getType(), resultTensor);
// Emit 'scf.yield' op
- opBuilder.create<scf::YieldOp>(loc, castResultTensor);
+ scf::YieldOp::create(opBuilder, loc, castResultTensor);
};
// Emit 'else' region of 'scf.if'
auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
- opBuilder.create<scf::YieldOp>(loc, operand);
+ scf::YieldOp::create(opBuilder, loc, operand);
};
// Emit 'scf.if' op
- auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
- emitThenRegion, emitElseRegion);
+ auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
+ emitThenRegion, emitElseRegion);
return ifOp.getResult(0);
}
@@ -930,8 +939,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
if (!resultType) {
return rewriter.notifyMatchFailure(operation, "failed to convert type");
}
- Value outputTensor = rewriter.create<tensor::EmptyOp>(
- loc, targetShape, resultType.getElementType());
+ Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
+ resultType.getElementType());
// Create affine maps. Input affine maps broadcast static dimensions of size
// 1. The output affine map is an identity map.
@@ -957,8 +966,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
// Emit 'linalg.generic' op
bool encounteredError = false;
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, outputTensor.getType(), operands, outputTensor, affineMaps,
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
Value opResult = createLinalgBodyCalculationForElementwiseOp(
@@ -968,7 +977,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
encounteredError = true;
return;
}
- opBuilder.create<linalg::YieldOp>(loc, opResult);
+ linalg::YieldOp::create(opBuilder, loc, opResult);
});
if (encounteredError)
return rewriter.notifyMatchFailure(
@@ -1078,42 +1087,42 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::AddFOp>(loc, args);
+ return arith::AddFOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::AddIOp>(loc, args);
+ return arith::AddIOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MulFOp>(loc, args);
+ return arith::MulFOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MulIOp>(loc, args);
+ return arith::MulIOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
+ return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::AndIOp>(loc, args);
+ return arith::AndIOp::create(rewriter, loc, args);
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::OrIOp>(loc, args);
+ return arith::OrIOp::create(rewriter, loc, args);
return {};
}
@@ -1139,7 +1148,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if (axis != i) {
reduceShape.push_back(inputTy.getDimSize(i));
if (inputTy.isDynamicDim(i))
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -1158,7 +1167,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
return rewriter.notifyMatchFailure(
op, "No initial value found for reduction operation");
- auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+ auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
@@ -1176,7 +1185,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Additionally we have to keep track of whether we've seen any non-NaN
// values and then do a final select based on this predicate.
auto trueAttr = rewriter.getBoolAttr(true);
- auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+ auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
auto emptyBoolTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
@@ -1202,8 +1211,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
}
bool didEncounterError = false;
- linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
- loc, inputs, outputs, axis,
+ linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
+ rewriter, loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
@@ -1219,21 +1228,22 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto oldAllResultsNanFlagValue = blockArgs[3];
// Unordered comparison of NaN against itself will always return true.
- Value isNaN = nestedBuilder.create<arith::CmpFOp>(
- op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+ Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
+ arith::CmpFPredicate::UNO,
+ inputValue, inputValue);
// If we've encountered a NaN, take the non-NaN value.
- auto selectOp = nestedBuilder.create<arith::SelectOp>(
- op->getLoc(), isNaN, initialValue, result);
+ auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
+ isNaN, initialValue, result);
// Update the flag which keeps track of whether we have seen a non-NaN
// value.
- auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
- op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+ auto newAllResultsNanFlagValue = arith::AndIOp::create(
+ nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
resultsToYield.push_back(selectOp);
resultsToYield.push_back(newAllResultsNanFlagValue);
} else {
resultsToYield.push_back(result);
}
- nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
+ linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
});
if (!didEncounterError)
@@ -1250,7 +1260,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto nanValueAttr = rewriter.getFloatAttr(
elementTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
- auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+ auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape,
@@ -1278,7 +1288,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
ins.push_back(linalgOp->getResult(0));
outs.push_back(finalEmptyTensor);
auto linalgSelect =
- rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+ linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
linalgOp = linalgSelect;
}
@@ -1350,7 +1360,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < outputTy.getRank(); i++) {
if (outputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -1401,16 +1411,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value multiplierConstant;
int64_t multiplierArg = 0;
if (multiplierValues.size() == 1) {
- multiplierConstant = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+ multiplierConstant = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
- genericInputs.push_back(rewriter.create<arith::ConstantOp>(
- loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+ genericInputs.push_back(arith::ConstantOp::create(
+ rewriter, loc,
+ DenseIntElementsAttr::get(multiplierType, multiplierValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
@@ -1424,16 +1435,16 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value shiftConstant;
int64_t shiftArg = 0;
if (shiftValues.size() == 1) {
- shiftConstant = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+ shiftConstant = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
SmallVector<AffineExpr, 2> shiftExprs = {
rewriter.getAffineDimExpr(rank - 1)};
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
- genericInputs.push_back(rewriter.create<arith::ConstantOp>(
- loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+ genericInputs.push_back(arith::ConstantOp::create(
+ rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
@@ -1444,13 +1455,13 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// Construct the indexing maps needed for linalg.generic ops.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(),
+ Value emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
ArrayRef<Value>({dynDims}));
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
- getNParallelLoopsAttrs(rank),
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
+ indexingMaps, getNParallelLoopsAttrs(rank),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value value = blockArgs[0];
@@ -1466,9 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = nestedBuilder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = arith::ConstantOp::create(
+ nestedBuilder, loc,
+ IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+ *maybeIZp));
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
@@ -1482,9 +1494,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned outBitWidth = outIntType.getWidth();
const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = nestedBuilder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
+ auto outputZp = arith::ConstantOp::create(
+ nestedBuilder, loc,
+ IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+ *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1501,24 +1514,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
if (valueTy.getIntOrFloatBitWidth() < 32) {
if (op.getInputUnsigned()) {
- value = nestedBuilder.create<arith::ExtUIOp>(
- nestedLoc, nestedBuilder.getI32Type(), value);
+ value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
+ nestedBuilder.getI32Type(), value);
} else {
- value = nestedBuilder.create<arith::ExtSIOp>(
- nestedLoc, nestedBuilder.getI32Type(), value);
+ value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
+ nestedBuilder.getI32Type(), value);
}
}
value =
- nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
+ arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
- value = nestedBuilder.create<tosa::ApplyScaleOp>(
- loc, nestedBuilder.getI32Type(), value, multiplier, shift,
- roundingMode);
+ value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
+ nestedBuilder.getI32Type(), value,
+ multiplier, shift, roundingMode);
// Move to the new zero-point.
value =
- nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
+ arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
// Saturate to the output size.
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
@@ -1530,18 +1543,18 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
}
- auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
- loc, nestedBuilder.getI32IntegerAttr(intMin));
- auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
- loc, nestedBuilder.getI32IntegerAttr(intMax));
+ auto intMinVal = arith::ConstantOp::create(
+ nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
+ auto intMaxVal = arith::ConstantOp::create(
+ nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
nestedBuilder, /*isUnsigned=*/false);
if (outIntType.getWidth() < 32) {
- value = nestedBuilder.create<arith::TruncIOp>(
- nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
- value);
+ value = arith::TruncIOp::create(
+ nestedBuilder, nestedLoc,
+ rewriter.getIntegerType(outIntType.getWidth()), value);
}
if (outIntType.isUnsignedInteger()) {
@@ -1550,7 +1563,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
outIntType, value)
.getResult(0);
}
- nestedBuilder.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(nestedBuilder, loc, value);
});
rewriter.replaceOp(op, linalgOp->getResults());
@@ -1608,48 +1621,49 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto collapseTy =
RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
inputTy.getElementType());
- Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
- reassociationMap);
+ Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
+ reassociationMap);
// Get any dynamic shapes that appear in the input format.
llvm::SmallVector<Value> outputDynSize;
if (inputTy.isDynamicDim(0))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
if (inputTy.isDynamicDim(3))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
// Generate the elementwise operation for casting scaling the input value.
auto genericTy = collapseTy.clone(resultTy.getElementType());
- Value empty = builder.create<tensor::EmptyOp>(
- genericTy.getShape(), resultTy.getElementType(), outputDynSize);
+ Value empty =
+ tensor::EmptyOp::create(builder, genericTy.getShape(),
+ resultTy.getElementType(), outputDynSize);
auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
utils::IteratorType::parallel);
- auto generic = builder.create<linalg::GenericOp>(
- genericTy, ValueRange{collapse}, ValueRange{empty},
+ auto generic = linalg::GenericOp::create(
+ builder, genericTy, ValueRange{collapse}, ValueRange{empty},
ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value value = args[0];
// This is the quantized case.
if (inputTy.getElementType() != resultTy.getElementType()) {
- value =
- b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
+ value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
+ value);
if (isBilinear && scale[0] != 0) {
- Value scaleY = b.create<arith::ConstantOp>(
- loc, b.getI32IntegerAttr(scale[0]));
- value = b.create<arith::MulIOp>(loc, value, scaleY);
+ Value scaleY = arith::ConstantOp::create(
+ b, loc, b.getI32IntegerAttr(scale[0]));
+ value = arith::MulIOp::create(b, loc, value, scaleY);
}
if (isBilinear && scale[2] != 0) {
- Value scaleX = b.create<arith::ConstantOp>(
- loc, b.getI32IntegerAttr(scale[2]));
- value = b.create<arith::MulIOp>(loc, value, scaleX);
+ Value scaleX = arith::ConstantOp::create(
+ b, loc, b.getI32IntegerAttr(scale[2]));
+ value = arith::MulIOp::create(b, loc, value, scaleX);
}
}
- b.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(b, loc, value);
});
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -1697,9 +1711,9 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
resizeShape.push_back(channels);
auto resizeTy = resultTy.clone(resizeShape);
- auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
- op.getOffset(), op.getBorder(),
- op.getMode());
+ auto resize =
+ tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
+ op.getOffset(), op.getBorder(), op.getMode());
// Collapse an unit result dims.
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1720,20 +1734,20 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
collapseShape.push_back(channels);
auto collapseTy = resultTy.clone(collapseShape);
- Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
- reassociationMap);
+ Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
+ resize, reassociationMap);
// Broadcast the collapsed shape to the output result.
llvm::SmallVector<Value> outputDynSize;
if (inputTy.isDynamicDim(0))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
if (inputTy.isDynamicDim(3))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
utils::IteratorType::parallel);
- Value empty = builder.create<tensor::EmptyOp>(
- resultTy.getShape(), resultTy.getElementType(), outputDynSize);
+ Value empty = tensor::EmptyOp::create(
+ builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
if (inputH != 1)
@@ -1751,7 +1765,7 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value value = args[0];
- b.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(b, loc, value);
});
return success();
@@ -1789,10 +1803,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
- *dynamicDimsOr);
- auto genericOp = b.create<linalg::GenericOp>(
- resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
+ auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
+ resultETy, *dynamicDimsOr);
+ auto genericOp = linalg::GenericOp::create(
+ b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()));
Value resize = genericOp.getResult(0);
@@ -1800,19 +1814,21 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
OpBuilder::InsertionGuard regionGuard(b);
b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
TypeRange({resultETy}), loc);
- Value batch = b.create<linalg::IndexOp>(0);
- Value y = b.create<linalg::IndexOp>(1);
- Value x = b.create<linalg::IndexOp>(2);
- Value channel = b.create<linalg::IndexOp>(3);
+ Value batch = linalg::IndexOp::create(b, 0);
+ Value y = linalg::IndexOp::create(b, 1);
+ Value x = linalg::IndexOp::create(b, 2);
+ Value channel = linalg::IndexOp::create(b, 3);
Value zeroI32 =
- b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
- Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
- Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
- Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
+ arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
+ Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
+ Value hMax =
+ arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
+ Value wMax =
+ arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
- Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
- Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
+ Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
+ Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
SmallVector<int64_t> scale, offset, border;
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
@@ -1824,16 +1840,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
Value yScaleN, yScaleD, xScaleN, xScaleD;
- yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
- yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
- xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
- xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
+ yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
+ yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
+ xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
+ xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
Value yOffset, xOffset, yBorder, xBorder;
- yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
- xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
- yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
- xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
+ yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
+ xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
+ yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
+ xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
// Compute the ix and dx values for both the X and Y dimensions.
auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
@@ -1846,16 +1862,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
- Value val = b.create<arith::MulIOp>(in, scaleD);
- val = b.create<arith::AddIOp>(val, offset);
- index = b.create<arith::FloorDivSIOp>(val, scaleN);
+ Value val = arith::MulIOp::create(b, in, scaleD);
+ val = arith::AddIOp::create(b, val, offset);
+ index = arith::FloorDivSIOp::create(b, val, scaleN);
// rx = x % scale_n
// dx = rx / scale_n
- Value r = b.create<arith::RemSIOp>(val, scaleN);
- Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
- Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
- delta = b.create<arith::DivFOp>(rFp, scaleNfp);
+ Value r = arith::RemSIOp::create(b, val, scaleN);
+ Value rFp = arith::SIToFPOp::create(b, floatTy, r);
+ Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
+ delta = arith::DivFOp::create(b, rFp, scaleNfp);
};
// Compute the ix and dx values for the X and Y dimensions - int case.
@@ -1870,11 +1886,11 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
// dx = x - ix * scale_n;
- Value val = b.create<arith::MulIOp>(in, scaleD);
- val = b.create<arith::AddIOp>(val, offset);
- index = b.create<arith::DivSIOp>(val, scaleN);
- delta = b.create<arith::MulIOp>(index, scaleN);
- delta = b.create<arith::SubIOp>(val, delta);
+ Value val = arith::MulIOp::create(b, in, scaleD);
+ val = arith::AddIOp::create(b, val, offset);
+ index = arith::DivSIOp::create(b, val, scaleN);
+ delta = arith::MulIOp::create(b, index, scaleN);
+ delta = arith::SubIOp::create(b, val, delta);
};
Value ix, iy, dx, dy;
@@ -1887,54 +1903,55 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
if (op.getMode() == "NEAREST_NEIGHBOR") {
- auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
+ auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
Value max, int size,
ImplicitLocOpBuilder &b) -> Value {
if (size == 1) {
- return b.create<arith::ConstantIndexOp>(0);
+ return arith::ConstantIndexOp::create(b, 0);
}
Value pred;
if (floatingPointMode) {
- auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
- pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
+ auto h =
+ arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
+ pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
} else {
- Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
- pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
- dvalDouble, scale);
+ Value dvalDouble = arith::ShLIOp::create(b, dval, one);
+ pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
+ dvalDouble, scale);
}
- auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
- val = b.create<arith::AddIOp>(val, offset);
+ auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
+ val = arith::AddIOp::create(b, val, offset);
val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
- return b.create<arith::IndexCastOp>(b.getIndexType(), val);
+ return arith::IndexCastOp::create(b, b.getIndexType(), val);
};
iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
- Value result = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, iy, ix, channel});
+ Value result = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, iy, ix, channel});
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
assert(op.getMode() == "BILINEAR");
- auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
+ auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
Value max, ImplicitLocOpBuilder &b) {
val0 = in;
- val1 = b.create<arith::AddIOp>(val0, oneVal);
+ val1 = arith::AddIOp::create(b, val0, oneVal);
val0 =
clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
val1 =
clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
- val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
- val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
+ val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
+ val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
};
// Linalg equivalent to the section below:
@@ -1946,27 +1963,27 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getClampedIdxs(y0, y1, imageH, iy, hMax, b);
getClampedIdxs(x0, x1, imageW, ix, wMax, b);
- Value y0x0 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y0, x0, channel});
- Value y0x1 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y0, x1, channel});
- Value y1x0 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y1, x0, channel});
- Value y1x1 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y1, x1, channel});
+ Value y0x0 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y0, x0, channel});
+ Value y0x1 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y0, x1, channel});
+ Value y1x0 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y1, x0, channel});
+ Value y1x1 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y1, x1, channel});
if (floatingPointMode) {
auto oneVal =
- b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
+ arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
auto interpolate = [&](Value val0, Value val1, Value delta,
int inputSize,
ImplicitLocOpBuilder &b) -> Value {
if (inputSize == 1)
return val0;
- Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
- Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
- Value mul1 = b.create<arith::MulFOp>(val1, delta);
- return b.create<arith::AddFOp>(mul0, mul1);
+ Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
+ Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
+ Value mul1 = arith::MulFOp::create(b, val1, delta);
+ return arith::AddFOp::create(b, mul0, mul1);
};
// Linalg equivalent to the section below:
@@ -1982,18 +1999,18 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// Linalg equivalent to the section below:
// result = topAcc * (unit_y - dy) + bottomAcc * dy
Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
} else {
// Perform in quantized space.
- y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
- y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
- y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
- y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
+ y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
+ y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
+ y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
+ y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
- dx = b.create<arith::ExtSIOp>(resultETy, dx);
- dy = b.create<arith::ExtSIOp>(resultETy, dy);
+ dx = arith::ExtSIOp::create(b, resultETy, dx);
+ dy = arith::ExtSIOp::create(b, resultETy, dy);
}
Value yScaleNExt = yScaleN;
@@ -2002,26 +2019,26 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
const int64_t scaleBitwidth =
xScaleN.getType().getIntOrFloatBitWidth();
if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
- yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
- xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
+ yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
+ xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
}
auto interpolate = [](Value val0, Value val1, Value weight1,
Value scale, int inputSize,
ImplicitLocOpBuilder &b) -> Value {
if (inputSize == 1)
- return b.create<arith::MulIOp>(val0, scale);
- Value weight0 = b.create<arith::SubIOp>(scale, weight1);
- Value mul0 = b.create<arith::MulIOp>(val0, weight0);
- Value mul1 = b.create<arith::MulIOp>(val1, weight1);
- return b.create<arith::AddIOp>(mul0, mul1);
+ return arith::MulIOp::create(b, val0, scale);
+ Value weight0 = arith::SubIOp::create(b, scale, weight1);
+ Value mul0 = arith::MulIOp::create(b, val0, weight0);
+ Value mul1 = arith::MulIOp::create(b, val1, weight1);
+ return arith::AddIOp::create(b, mul0, mul1);
};
Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
Value result =
interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
}
}
}
@@ -2072,11 +2089,11 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
- Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
+ Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
// First fill the output buffer with the init value.
auto emptyTensor = rewriter
@@ -2094,22 +2111,22 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
llvm::SmallVector<Value> indices;
for (unsigned int i = 0; i < inputTy.getRank(); i++) {
Value index =
- rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+ linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
if (i == axis) {
- auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
+ auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
auto sizeMinusOne =
- rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
- index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
- index);
+ arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
+ index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
+ index);
}
indices.push_back(index);
}
- auto extract = nestedBuilder.create<tensor::ExtractOp>(
- nestedLoc, input, indices);
- nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
- extract.getResult());
+ auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
+ input, indices);
+ linalg::YieldOp::create(nestedBuilder, op.getLoc(),
+ extract.getResult());
});
return success();
}
@@ -2148,12 +2165,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- op.getLoc(), genericShape, elementTy, dynDims);
+ auto emptyTensor = tensor::EmptyOp::create(
+ rewriter, op.getLoc(), genericShape, elementTy, dynDims);
// We needs to map the input shape to the non-broadcasted dimensions.
SmallVector<AffineExpr, 4> dimExprs;
@@ -2168,12 +2185,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
SmallVector<AffineMap, 2> affineMaps = {
readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, RankedTensorType::get(genericShape, elementTy), input,
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(genericShape.size()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+ linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
});
auto shapeValue = getTosaConstShape(
@@ -2220,7 +2237,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) && i != axis) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -2229,8 +2246,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
.create<tensor::EmptyOp>(loc, resultTy.getShape(),
outElementTy, dynDims)
.getResult();
- auto fillValueIdx = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(outElementTy, 0));
+ auto fillValueIdx = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
auto filledTensorIdx =
rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
@@ -2250,7 +2267,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
argmaxOp, "unsupported tosa.argmax element type");
auto fillValueMax =
- rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
+ arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
auto filledTensorMax =
rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValueMax},
@@ -2274,8 +2291,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
rewriter.getContext());
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
@@ -2283,42 +2300,46 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
auto oldIndex = blockArgs[1];
auto oldValue = blockArgs[2];
- Value newIndex = rewriter.create<arith::IndexCastOp>(
- nestedLoc, oldIndex.getType(),
- rewriter.create<linalg::IndexOp>(loc, axis));
+ Value newIndex = arith::IndexCastOp::create(
+ rewriter, nestedLoc, oldIndex.getType(),
+ linalg::IndexOp::create(rewriter, loc, axis));
Value predicate;
if (isa<FloatType>(inElementTy)) {
if (argmaxOp.getNanMode() == "IGNORE") {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
- predicate = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+ predicate = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::OGT,
+ newValue, oldValue);
} else {
// Update max value if either of the following is true:
// - new value is bigger
// - cur max is not NaN and new value is NaN
- Value gt = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
- Value oldNonNaN = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
- predicate = rewriter.create<arith::AndIOp>(
- nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
+ Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::UGT,
+ newValue, oldValue);
+ Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::ORD,
+ oldValue, oldValue);
+ predicate = arith::AndIOp::create(
+ rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
}
} else if (isa<IntegerType>(inElementTy)) {
- predicate = rewriter.create<arith::CmpIOp>(
- nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
+ predicate = arith::CmpIOp::create(rewriter, nestedLoc,
+ arith::CmpIPredicate::sgt,
+ newValue, oldValue);
} else {
didEncounterError = true;
return;
}
- auto resultMax = rewriter.create<arith::SelectOp>(
- nestedLoc, predicate, newValue, oldValue);
- auto resultIndex = rewriter.create<arith::SelectOp>(
- nestedLoc, predicate, newIndex, oldIndex);
- nestedBuilder.create<linalg::YieldOp>(
- nestedLoc, ValueRange({resultIndex, resultMax}));
+ auto resultMax = arith::SelectOp::create(
+ rewriter, nestedLoc, predicate, newValue, oldValue);
+ auto resultIndex = arith::SelectOp::create(
+ rewriter, nestedLoc, predicate, newIndex, oldIndex);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc,
+ ValueRange({resultIndex, resultMax}));
});
if (didEncounterError)
@@ -2363,19 +2384,19 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &b, Location loc, ValueRange args) {
auto indexValue = args[0];
- auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
- Value index1 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), indexValue);
- auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
- Value extract = rewriter.create<tensor::ExtractOp>(
- loc, input, ValueRange{index0, index1, index2});
- rewriter.create<linalg::YieldOp>(loc, extract);
+ auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
+ Value index1 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), indexValue);
+ auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
+ Value extract = tensor::ExtractOp::create(
+ rewriter, loc, input, ValueRange{index0, index1, index2});
+ linalg::YieldOp::create(rewriter, loc, extract);
});
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
@@ -2424,7 +2445,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
for (int i = 0; i < resultTy.getRank(); ++i) {
if (inputTy.isDynamicDim(i)) {
dynDims.push_back(
- rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+ tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
}
}
@@ -2437,9 +2458,9 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
rewriter.getMultiDimIdentityMap(resultTy.getRank()),
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
- getNParallelLoopsAttrs(resultTy.getRank()));
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
+ affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
rewriter.replaceOp(op, genericOp.getResult(0));
{
@@ -2452,69 +2473,69 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
rewriter.setInsertionPointToStart(block);
if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
resultElementTy.isInteger(8)) {
- Value index = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), inputValue);
- Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
- index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
- index, offset);
+ Value index = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), inputValue);
+ Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
+ index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
+ index, offset);
Value extract =
- rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
- rewriter.create<linalg::YieldOp>(loc, extract);
+ tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
+ linalg::YieldOp::create(rewriter, loc, extract);
return success();
}
if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
resultElementTy.isInteger(32)) {
- Value extend = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), inputValue);
-
- auto offset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(32768));
- auto seven = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(7));
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1));
- auto b1111111 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(127));
+ Value extend = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), inputValue);
+
+ auto offset = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(32768));
+ auto seven = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(7));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(1));
+ auto b1111111 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(127));
// Compute the index and fractional part from the input value:
// value = value + 32768
// index = value >> 7;
// fraction = 0x01111111 & value
- auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
- Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
+ auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
+ Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
Value fraction =
- rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
+ arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
// Extract the base and next values from the table.
// base = (int32_t) table[index];
// next = (int32_t) table[index + 1];
- Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
+ Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
- index = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), index);
- indexPlusOne = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), indexPlusOne);
+ index = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), index);
+ indexPlusOne = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), indexPlusOne);
Value base =
- rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
- Value next = rewriter.create<tensor::ExtractOp>(
- loc, table, ValueRange{indexPlusOne});
+ tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
+ Value next = tensor::ExtractOp::create(rewriter, loc, table,
+ ValueRange{indexPlusOne});
base =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
+ arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
next =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
+ arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
// Use the fractional part to interpolate between the input values:
// result = (base << 7) + (next - base) * fraction
- Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
- Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
- Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
+ Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
+ Value diff = arith::SubIOp::create(rewriter, loc, next, base);
+ Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
Value result =
- rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
+ arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
- rewriter.create<linalg::YieldOp>(loc, result);
+ linalg::YieldOp::create(rewriter, loc, result);
return success();
}
@@ -2532,8 +2553,8 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
OpFoldResult ofr) {
- auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
- auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
+ auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ auto two = arith::ConstantIndexOp::create(builder, loc, 2);
auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
@@ -2562,9 +2583,9 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
RankedTensorType type,
llvm::ArrayRef<Value> dynamicSizes) {
auto emptyTensor =
- rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
+ tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
- auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+ auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
@@ -2574,18 +2595,18 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
+ // Converts the index-typed `value` to the float type `type`.
+ // The index is first cast as unsigned (arith::IndexCastUIOp) to i64 when
+ // `type` is wider than 32 bits, or to i32 otherwise. It is then converted
+ // with an unsigned int-to-float cast (arith::UIToFPOp).
+ // NOTE(review): the unsigned casts assume `value` is non-negative
+ // (e.g. a loop index or dimension size); confirm at call sites.
static Value castIndexToFloat(OpBuilder &builder, Location loc,
FloatType type, Value value) {
- auto integerVal = builder.create<arith::IndexCastUIOp>(
- loc,
+ auto integerVal = arith::IndexCastUIOp::create(
+ builder, loc,
type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
: builder.getI32Type(),
value);
- return builder.create<arith::UIToFPOp>(loc, type, integerVal);
+ return arith::UIToFPOp::create(builder, loc, type, integerVal);
}
+ // Creates a linalg::IndexOp with dimension argument `index` and converts
+ // its result to the float type `type` via castIndexToFloat.
+ // NOTE(review): presumably only valid while the builder's insertion point
+ // is inside a linalg op region; confirm at call sites.
static Value createLinalgIndex(OpBuilder &builder, Location loc,
FloatType type, int64_t index) {
- auto indexVal = builder.create<linalg::IndexOp>(loc, index);
+ auto indexVal = linalg::IndexOp::create(builder, loc, index);
return castIndexToFloat(builder, loc, type, indexVal);
}
@@ -2640,7 +2661,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
// Constants and dimension sizes
auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
- auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+ auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
@@ -2650,43 +2671,45 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
Value sumImag = args[2];
// Indices for angle computation
- Value oy = builder.create<linalg::IndexOp>(loc, 1);
- Value ox = builder.create<linalg::IndexOp>(loc, 2);
- Value iy = builder.create<linalg::IndexOp>(loc, 3);
- Value ix = builder.create<linalg::IndexOp>(loc, 4);
+ Value oy = linalg::IndexOp::create(builder, loc, 1);
+ Value ox = linalg::IndexOp::create(builder, loc, 2);
+ Value iy = linalg::IndexOp::create(builder, loc, 3);
+ Value ix = linalg::IndexOp::create(builder, loc, 4);
// Calculating angle without integer parts of components as sin/cos are
// periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
// / W);
- auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
- auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+ auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
+ auto ixXox = index::MulOp::create(builder, loc, ix, ox);
- auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
- auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+ auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
+ auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
- auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
- auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+ auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
+ auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
+ auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
+ auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
// realComponent = valReal * cos(angle)
// imagComponent = valReal * sin(angle)
- auto cosAngle = builder.create<math::CosOp>(loc, angle);
- auto sinAngle = builder.create<math::SinOp>(loc, angle);
+ auto cosAngle = math::CosOp::create(builder, loc, angle);
+ auto sinAngle = math::SinOp::create(builder, loc, angle);
auto realComponent =
- builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+ arith::MulFOp::create(builder, loc, valReal, cosAngle);
auto imagComponent =
- builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+ arith::MulFOp::create(builder, loc, valReal, sinAngle);
// outReal = sumReal + realComponent
// outImag = sumImag - imagComponent
- auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
- auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
+ auto outReal =
+ arith::AddFOp::create(builder, loc, sumReal, realComponent);
+ auto outImag =
+ arith::SubFOp::create(builder, loc, sumImag, imagComponent);
- builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+ linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
};
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
@@ -2760,7 +2783,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
// Constants and dimension sizes
auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
- auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+ auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
Value constH =
RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
Value constW =
@@ -2773,57 +2796,59 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
Value sumImag = args[3];
// Indices for angle computation
- Value oy = builder.create<linalg::IndexOp>(loc, 1);
- Value ox = builder.create<linalg::IndexOp>(loc, 2);
- Value iy = builder.create<linalg::IndexOp>(loc, 3);
- Value ix = builder.create<linalg::IndexOp>(loc, 4);
+ Value oy = linalg::IndexOp::create(builder, loc, 1);
+ Value ox = linalg::IndexOp::create(builder, loc, 2);
+ Value iy = linalg::IndexOp::create(builder, loc, 3);
+ Value ix = linalg::IndexOp::create(builder, loc, 4);
// float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
// ox) % W ) / W);
- auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
- auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+ auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
+ auto ixXox = index::MulOp::create(builder, loc, ix, ox);
- auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
- auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+ auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
+ auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
auto iyRemFloat =
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
auto ixRemFloat =
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
+ auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
+ auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
- auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
- auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+ auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
+ auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
if (inverse.getValue()) {
- angle = builder.create<arith::MulFOp>(
- loc, angle,
- rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
+ angle = arith::MulFOp::create(
+ builder, loc, angle,
+ arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(real_el_ty, -1.0)));
}
// realComponent = val_real * cos(a) + val_imag * sin(a);
// imagComponent = -val_real * sin(a) + val_imag * cos(a);
- auto cosAngle = builder.create<math::CosOp>(loc, angle);
- auto sinAngle = builder.create<math::SinOp>(loc, angle);
+ auto cosAngle = math::CosOp::create(builder, loc, angle);
+ auto sinAngle = math::SinOp::create(builder, loc, angle);
- auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
- auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
- auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
+ auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
+ auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
+ auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
- auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
- auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+ auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
+ auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
- auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
+ auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
// outReal = sumReal + realComponent
// outImag = sumImag - imagComponent
- auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
- auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
+ auto outReal =
+ arith::AddFOp::create(builder, loc, sumReal, realComponent);
+ auto outImag =
+ arith::AddFOp::create(builder, loc, sumImag, imagComponent);
- builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+ linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
};
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 00b9a065dfb3d..3a205246ddd9e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -52,11 +52,11 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
highIndices.push_back(rewriter.getIndexAttr(highPad));
}
- Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
+ Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);
- return rewriter.create<tensor::PadOp>(
- loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
- highIndices, padValue);
+ return tensor::PadOp::create(rewriter, loc,
+ RankedTensorType::get(paddedShape, inputETy),
+ input, lowIndices, highIndices, padValue);
}
static mlir::Value
@@ -72,10 +72,10 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value biasVal = args[0];
Type resType = args[1].getType();
if (resType != biasVal.getType()) {
- biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+ biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
}
- Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
- builder.create<linalg::YieldOp>(loc, added);
+ Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
})
.getResult(0);
}
@@ -134,19 +134,19 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
if (resType != biasVal.getType()) {
biasVal =
resultTy.getElementType().isFloat()
- ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
.getResult()
- : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+ : arith::ExtSIOp::create(builder, loc, resType, biasVal)
.getResult();
}
- builder.create<linalg::YieldOp>(loc, biasVal);
+ linalg::YieldOp::create(builder, loc, biasVal);
})
.getResult(0);
}
static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
- return builder.create<arith::ConstantIndexOp>(attr);
+ return arith::ConstantIndexOp::create(builder, attr);
}
// Calculating the output width/height using the formula:
@@ -160,22 +160,22 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
int64_t dilationAttr,
OpBuilder &rewriter) {
ImplicitLocOpBuilder builder(loc, rewriter);
- auto one = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(inputDim.getType(), 1));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(inputDim.getType(), 1));
Value padBefore = reifyConstantDim(padBeforeAttr, builder);
- Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
+ Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
Value padAfter = reifyConstantDim(padAfterAttr, builder);
- Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
+ Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
- Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
+ Value subOne = arith::SubIOp::create(builder, kernelDim, one);
Value dilation = reifyConstantDim(dilationAttr, builder);
- Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
- Value addOne = builder.create<arith::AddIOp>(dilated, one);
+ Value dilated = arith::MulIOp::create(builder, dilation, subOne);
+ Value addOne = arith::AddIOp::create(builder, dilated, one);
- Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
+ Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
Value stride = reifyConstantDim(strideAttr, builder);
- Value divide = builder.create<arith::DivUIOp>(subtract, stride);
- return builder.create<arith::AddIOp>(divide, one);
+ Value divide = arith::DivUIOp::create(builder, subtract, stride);
+ return arith::AddIOp::create(builder, divide, one);
}
// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
@@ -198,9 +198,9 @@ static SmallVector<Value> inferDynamicDimsForConv(
auto padBottom = padAttr[i * 2 + 1];
auto stride = strideAttr[i];
auto dilation = dilationAttr[i];
- Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
+ Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
Value kernelDynDim =
- rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
+ tensor::DimOp::create(rewriter, loc, weight, kernelDim);
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
dynDims[inputDim] =
getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
@@ -211,7 +211,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
// Get the batch/channels dimensions.
for (int i = 0; i < inputRank; i++) {
if (resultTy.isDynamicDim(i) && !dynDims[i])
- dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+ dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -350,8 +350,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
- weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermAttr);
+ weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
+ weightPermAttr);
}
}
@@ -372,8 +372,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
- weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermAttr);
+ weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
+ weightPermAttr);
}
// Extract the attributes for convolution.
@@ -384,8 +384,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto strideAttr = rewriter.getI64TensorAttr(stride);
auto dilationAttr = rewriter.getI64TensorAttr(dilation);
- Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), accETy, filteredDims);
+ Value biasEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), accETy, filteredDims);
Value broadcastBias =
linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
@@ -394,8 +394,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
- auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
- auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+ auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
+ auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
Value conv =
rewriter
@@ -417,7 +417,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (resultTy != accTy)
- conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+ conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);
rewriter.replaceOp(op, conv);
return success();
@@ -526,16 +526,16 @@ class DepthwiseConvConverter
accETy);
auto resultZeroAttr = rewriter.getZeroAttr(accETy);
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, linalgConvTy.getShape(), accETy, filteredDims);
- Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+ Value emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
+ Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, filteredDims);
+ Value biasEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
// Broadcast the initial value to the output tensor before convolving.
SmallVector<AffineMap, 4> indexingMaps;
@@ -553,16 +553,16 @@ class DepthwiseConvConverter
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (accETy != resultETy)
- conv = rewriter.create<tosa::CastOp>(
- loc,
+ conv = tosa::CastOp::create(
+ rewriter, loc,
RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
resultETy),
conv);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
- Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ Value convReshape = tensor::CollapseShapeOp::create(
+ rewriter, loc, resultTy, conv, reassociationMap);
Value result =
rewriter
@@ -574,20 +574,20 @@ class DepthwiseConvConverter
ValueRange args) {
Value added;
if (llvm::isa<FloatType>(inputETy))
- added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
- args[1]);
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
else
- added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
- args[1]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
})
.getResult(0);
rewriter.replaceOp(op, result);
} else {
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
- auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
- auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
+ auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
+ auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
Value conv =
rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
@@ -596,8 +596,8 @@ class DepthwiseConvConverter
.getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
- Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ Value convReshape = tensor::CollapseShapeOp::create(
+ rewriter, loc, resultTy, conv, reassociationMap);
Value result = linalgIntBroadcastExtSIAdd(
rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
rewriter.replaceOp(op, result);
@@ -621,23 +621,24 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
- dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
+ dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
}
if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
- dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
+ dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
}
if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
- dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
+ dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
- Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
+ Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
+ outputTy.getElementType(), filteredDims);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
@@ -670,10 +671,10 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
return success();
}
- auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(aZpVal));
- auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(bZpVal));
+ auto aZp = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(aZpVal));
+ auto bZp = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(bZpVal));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -702,7 +703,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// Batch dimension
if (resultTy.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
// Height/width dimensions
for (int64_t dim : {1, 2}) {
@@ -713,10 +714,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
int64_t index = dim - 1;
// Input height/width
- Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+ Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
// Kernel height/width
- Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+ Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
// Output height/width
Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
@@ -727,7 +728,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// Channel dimension
if (resultTy.isDynamicDim(3))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
return dynamicDims;
}
@@ -776,7 +777,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
- Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
+ Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
@@ -785,15 +786,16 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
// Create the linalg op that performs pooling.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
+ Value emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultTy.getElementType(), dynamicDims);
Value filledEmptyTensor =
- rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
+ linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
.result();
Value fakeWindowDims =
- rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
+ tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
if (isUnsigned) {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
@@ -802,8 +804,8 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
return llvm::success();
}
- auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
- op->getLoc(), ArrayRef<Type>{resultTy},
+ auto resultOp = linalg::PoolingNhwcMaxOp::create(
+ rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
dilationAttr);
@@ -823,9 +825,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
if (nanMode == "IGNORE") {
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(),
- resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(),
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
+ resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+ resultOp.getIteratorTypesArray(),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
IRMapping map;
auto oldBlock = resultOp.getRegion().begin();
@@ -833,12 +836,12 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
auto &oldMaxOp = *resultOp.getBlock()->begin();
map.map(oldArgs, blockArgs);
auto *newOp = opBuilder.clone(oldMaxOp, map);
- Value isNaN = opBuilder.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, blockArgs.front(),
- blockArgs.front());
- auto selectOp = opBuilder.create<arith::SelectOp>(
- loc, isNaN, blockArgs.back(), newOp->getResult(0));
- opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
+ Value isNaN =
+ arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
+ blockArgs.front(), blockArgs.front());
+ auto selectOp = arith::SelectOp::create(
+ opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
+ linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
});
rewriter.replaceOp(resultOp, genericOp);
}
@@ -894,7 +897,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
auto initialAttr = rewriter.getZeroAttr(accETy);
- Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
+ Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
@@ -903,8 +906,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
// Create the linalg op that performs pooling.
- Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, accTy.getShape(), accETy, dynamicDims);
+ Value poolEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
rewriter
@@ -913,7 +916,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
.result();
Value fakeWindowDims =
- rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);
+ tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
Value poolingOp = rewriter
@@ -925,24 +928,24 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Normalize the summed value by the number of elements grouped in each
// pool.
- Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
- Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+ Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
+ Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- iH = rewriter.create<arith::SubIOp>(loc, iH, one);
- iW = rewriter.create<arith::SubIOp>(loc, iW, one);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ iH = arith::SubIOp::create(rewriter, loc, iH, one);
+ iW = arith::SubIOp::create(rewriter, loc, iW, one);
- Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, dynamicDims);
+ Value genericEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
ValueRange{genericEmptyTensor},
ArrayRef<AffineMap>({affineMap, affineMap}),
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &b, Location loc, ValueRange args) {
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
// Determines what the portion of valid input is covered by the
// kernel.
@@ -950,30 +953,30 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
if (pad == 0)
return valid;
- auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
- Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
+ auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
+ Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
- Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
- return rewriter.create<arith::AddIOp>(loc, valid, offset)
+ Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
+ return arith::AddIOp::create(rewriter, loc, valid, offset)
->getResult(0);
};
auto coverageFn = [&](int64_t i, Value isize) -> Value {
Value strideVal =
- rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+ arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
Value val =
- rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+ arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);
// Find the position relative to the input tensor's ends.
- Value left = rewriter.create<linalg::IndexOp>(loc, i);
- Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
- left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
- right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+ Value left = linalg::IndexOp::create(rewriter, loc, i);
+ Value right = arith::SubIOp::create(rewriter, loc, isize, left);
+ left = arith::MulIOp::create(rewriter, loc, left, strideVal);
+ right = arith::MulIOp::create(rewriter, loc, right, strideVal);
// Determine how much padding was included.
val = padFn(val, left, pad[i * 2]);
val = padFn(val, right, pad[i * 2 + 1]);
- return rewriter.create<arith::MaxSIOp>(loc, one, val);
+ return arith::MaxSIOp::create(rewriter, loc, one, val);
};
// Compute the indices from either end.
@@ -981,70 +984,70 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Value kW3 = coverageFn(2, iW);
// Compute the total number of elements and normalize.
- auto count = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(),
- rewriter.create<arith::MulIOp>(loc, kH3, kW3));
+ auto count = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ arith::MulIOp::create(rewriter, loc, kH3, kW3));
// Divide by the number of summed values. For floats this is just
// a div however for quantized values input normalization had
// to be applied.
Value poolVal = args[0];
if (isa<FloatType>(accETy)) {
- auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
- poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
+ auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
+ poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
->getResult(0);
if (accETy.getIntOrFloatBitWidth() >
resultETy.getIntOrFloatBitWidth())
poolVal =
- rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
+ arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
} else {
// If we have quantization information we need to apply an offset
// for the input zp value.
if (inputZpVal != 0) {
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(accETy, inputZpVal));
+ auto inputZp = arith::ConstantOp::create(
+ rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
- rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
+ arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
poolVal =
- rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
+ arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
}
// Compute: k = 32 - count_leading_zeros(value - 1)
- Value one32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1));
- Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(32));
+ Value one32 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(1));
+ Value thirtyTwo32 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(32));
Value countSubOne =
- rewriter.create<arith::SubIOp>(loc, count, one32);
+ arith::SubIOp::create(rewriter, loc, count, one32);
Value leadingZeros =
- rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+ math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
Value k =
- rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+ arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);
// Compute: numerator = ((1 << 30) + 1) << k
Value k64 =
- rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
- Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+ arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
+ Value thirtyShiftPlusOne = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
Value numerator =
- rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+ arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
// Compute: scale.multiplier = numerator / value;
- Value count64 = rewriter.create<arith::ExtUIOp>(
- loc, rewriter.getI64Type(), count);
+ Value count64 = arith::ExtUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), count);
Value multiplier =
- rewriter.create<arith::DivUIOp>(loc, numerator, count64);
- multiplier = rewriter.create<arith::TruncIOp>(
- loc, rewriter.getI32Type(), multiplier);
+ arith::DivUIOp::create(rewriter, loc, numerator, count64);
+ multiplier = arith::TruncIOp::create(
+ rewriter, loc, rewriter.getI32Type(), multiplier);
// Compute: scale.shift = 30 + k
Value k8 =
- rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
- Value thirty8 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI8IntegerAttr(30));
- Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
+ arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
+ Value thirty8 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI8IntegerAttr(30));
+ Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
rewriter
@@ -1056,20 +1059,21 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply output
// zeropoint.
if (outputZpVal != 0) {
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
- scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
+ auto outputZp = arith::ConstantOp::create(
+ rewriter, loc,
+ b.getIntegerAttr(scaled.getType(), outputZpVal));
+ scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
.getResult();
}
// Apply Clip.
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
- auto min = rewriter.create<arith::ConstantIntOp>(
- loc, accETy,
+ auto min = arith::ConstantIntOp::create(
+ rewriter, loc, accETy,
APInt::getSignedMinValue(outBitwidth).getSExtValue());
- auto max = rewriter.create<arith::ConstantIntOp>(
- loc, accETy,
+ auto max = arith::ConstantIntOp::create(
+ rewriter, loc, accETy,
APInt::getSignedMaxValue(outBitwidth).getSExtValue());
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
/*isUnsigned=*/false);
@@ -1078,11 +1082,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Convert type.
if (resultETy != clamp.getType()) {
poolVal =
- rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
+ arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
}
}
- rewriter.create<linalg::YieldOp>(loc, poolVal);
+ linalg::YieldOp::create(rewriter, loc, poolVal);
});
rewriter.replaceOp(op, genericOp.getResult(0));
@@ -1107,8 +1111,9 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
auto permutedSizes =
applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
- auto permutedInit = rewriter.create<tensor::EmptyOp>(
- loc, permutedSizes, op.getInput1().getType().getElementType());
+ auto permutedInit =
+ tensor::EmptyOp::create(rewriter, loc, permutedSizes,
+ op.getInput1().getType().getElementType());
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
op, op.getInput1(), permutedInit,
llvm::to_vector(llvm::map_range(
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index 7dbccd19a0518..b83f5ec9b0283 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -27,8 +27,8 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto variableType = tosa::getVariableType(op);
- auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
- op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
+ auto newVariable = mlir::ml_program::GlobalOp::create(
+ rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
@@ -45,8 +45,8 @@ class VariableWriteOpConverter
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
- auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
- op.getLoc(), globalSymbolRef, op.getInput1());
+ auto newVariableWrite = ml_program::GlobalStoreOp::create(
+ rewriter, op.getLoc(), globalSymbolRef, op.getInput1());
rewriter.replaceOp(op, newVariableWrite);
return success();
}
@@ -60,8 +60,8 @@ class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
- auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
- op.getLoc(), op.getType(), globalSymbolRef);
+ auto newVariableRead = ml_program::GlobalLoadOp::create(
+ rewriter, op.getLoc(), op.getType(), globalSymbolRef);
rewriter.replaceOp(op, newVariableRead);
return success();
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 03f9d20ad69de..aa6b4164e9876 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -30,7 +30,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion,
auto yield = cast<YieldOp>(headBlock->getTerminator());
rewriter.setInsertionPoint(yield);
- rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
+ scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
rewriter.eraseOp(yield);
headBlock->eraseArguments(0, headBlock->getNumArguments());
@@ -46,13 +46,13 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
auto yield = cast<YieldOp>(headBlock->getTerminator());
rewriter.setInsertionPoint(yield);
if (isCond) {
- auto condition =
- rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
- rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
- headBlock->getArguments());
+ auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
+ yield.getOperand(0));
+ scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
+ headBlock->getArguments());
} else {
rewriter.setInsertionPoint(yield);
- rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
+ scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
}
rewriter.eraseOp(yield);
}
@@ -66,9 +66,9 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
- rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
- auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
- condition, true);
+ tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
+ auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
+ condition, true);
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
rewriter);
@@ -88,7 +88,7 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
static Value createIndexConst(OpBuilder &builder, Location loc,
int64_t value) {
- return builder.create<arith::ConstantIndexOp>(loc, value);
+ return arith::ConstantIndexOp::create(builder, loc, value);
}
public:
@@ -119,9 +119,9 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto n = ivs[0];
// Read the index and cast it to index type
- auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
- auto castIndex = builder.create<arith::IndexCastOp>(
- loc, builder.getIndexType(), index);
+ auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
+ auto castIndex = arith::IndexCastOp::create(
+ builder, loc, builder.getIndexType(), index);
// Offset, sizes, and strides for the input tensor
auto inputOffset = llvm::to_vector(ivs);
@@ -130,13 +130,13 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
llvm::SmallVector<Value> sizes = {one, one, dimC};
llvm::SmallVector<Value> strides = {one, one, one};
- auto slice = builder.create<tensor::ExtractSliceOp>(
- loc, input, inputOffset, sizes, strides);
+ auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
+ inputOffset, sizes, strides);
// Insert the slice into the output accumulator tensor.
llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
- auto updated = builder.create<tensor::InsertSliceOp>(
- loc, slice, args[0], outputOffset, sizes, strides);
+ auto updated = tensor::InsertSliceOp::create(
+ builder, loc, slice, args[0], outputOffset, sizes, strides);
return {updated};
};
@@ -155,8 +155,8 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
- auto newWhile = rewriter.create<scf::WhileOp>(
- op.getLoc(), op.getResultTypes(), op.getInputList());
+ auto newWhile = scf::WhileOp::create(
+ rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c6cbcb0f8ab2b..2945ae3b49f1f 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -308,15 +308,15 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
if (ShapedType::isStatic(sizes.back()))
continue;
- auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
- auto offset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(sliceStarts[index]));
- dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+ auto dim = tensor::DimOp::create(rewriter, loc, input, index);
+ auto offset = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
+ dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
}
- auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
- ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
+ auto newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
+ dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getDenseI64ArrayAttr(strides));
@@ -361,7 +361,7 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
loc, padOp.getPadConst(),
- ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)}));
+ ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));
if (!padConstant) {
return rewriter.notifyMatchFailure(
@@ -375,16 +375,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
highValues.reserve(rank);
for (int i = 0; i < rank; i++) {
- Value lowVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(paddingVals[2 * i]));
- Value highVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
+ Value lowVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+ Value highVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
lowValues.push_back(lowVal);
highValues.push_back(highVal);
}
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, padOp.getType(), input, lowValues, highValues, padConstant);
+ auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
+ lowValues, highValues, padConstant);
rewriter.replaceOp(padOp, newPadOp.getResult());
return success();
@@ -402,7 +402,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
Location loc = op.getLoc();
int axis = op.getAxis();
Value axisValue =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
int64_t rank = resultType.getRank();
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
@@ -439,8 +439,9 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
}
- Value result = rewriter.create<tensor::EmptyOp>(
- loc, resultType.getShape(), resultType.getElementType(), dynDims);
+ Value result =
+ tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
+ resultType.getElementType(), dynDims);
for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d6f9495b2567c..125ea1eb60ed6 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -226,22 +226,22 @@ struct BroadcastOpToArmSMELowering
(srcVectorType && (srcVectorType.getRank() == 0))) {
// Broadcast scalar or 0-d vector to 1-d vector.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- broadcastOp1D = rewriter.create<vector::BroadcastOp>(
- loc, tileSliceType, broadcastOp.getSource());
+ broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
+ broadcastOp.getSource());
} else if (srcVectorType && (srcVectorType.getRank() == 1))
// Value to broadcast is already a 1-d vector, nothing to do.
broadcastOp1D = broadcastOp.getSource();
else
return failure();
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to broadcast the value
// to each tile slice.
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
@@ -292,15 +292,15 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
// First, broadcast the scalar to a 1-d vector.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
- loc, tileSliceType, splatOp.getInput());
+ Value broadcastOp1D = vector::BroadcastOp::create(
+ rewriter, loc, tileSliceType, splatOp.getInput());
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
@@ -370,22 +370,22 @@ struct TransposeOpToArmSMELowering
// Allocate buffer to store input tile to.
Value vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- Value minTileSlices = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ Value minTileSlices = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
Value c0 =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0));
Value numTileSlices =
- rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
+ arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
auto bufferType =
MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
tileType.getElementType());
- auto buffer = rewriter.create<memref::AllocaOp>(
- loc, bufferType, ValueRange{numTileSlices, numTileSlices});
+ auto buffer = memref::AllocaOp::create(
+ rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices});
// Store input tile.
- auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
- loc, input, buffer, ValueRange{c0, c0});
+ auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
+ buffer, ValueRange{c0, c0});
// Reload input tile vertically.
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
@@ -488,10 +488,10 @@ struct VectorOuterProductToArmSMELowering
Value rhsMaskDim = createMaskOp.getOperand(1);
VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
- Value lhsMask =
- rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
- Value rhsMask =
- rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
+ Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
+ lhsMaskDim);
+ Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
+ rhsMaskDim);
return std::make_pair(lhsMask, rhsMask);
}
@@ -531,8 +531,8 @@ struct VectorExtractToArmSMELowering
}
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, sourceVector, sliceIndex);
+ auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, sourceVector, sliceIndex);
if (position.size() == 1) {
// Single index case: Extracts a 1D slice.
@@ -593,10 +593,10 @@ struct VectorInsertToArmSMELowering
if (position.size() == 2) {
// Two indices case: Insert single element into tile.
// We need to first extract the existing slice and update the element.
- tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, insertOp.getDest(), sliceIndex);
- tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
- position[1]);
+ tileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, insertOp.getDest(), sliceIndex);
+ tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
+ position[1]);
}
// Insert the slice into the destination tile.
@@ -642,23 +642,24 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
auto loc = printOp.getLoc();
// Create a loop over the rows of the tile.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto minTileRows =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
{
// Loop body.
rewriter.setInsertionPointToStart(forOp.getBody());
// Extract the current row from the tile.
Value rowIndex = forOp.getInductionVar();
- auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, printOp.getSource(), rowIndex);
+ auto tileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, printOp.getSource(), rowIndex);
// Print the row with a 1D vector.print.
- rewriter.create<vector::PrintOp>(loc, tileSlice,
- printOp.getPunctuation());
+ vector::PrintOp::create(rewriter, loc, tileSlice,
+ printOp.getPunctuation());
}
rewriter.eraseOp(printOp);
@@ -707,8 +708,8 @@ struct FoldTransferWriteOfExtractTileSlice
Value mask = writeOp.getMask();
if (!mask) {
auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
- mask = rewriter.create<arith::ConstantOp>(
- writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+ mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
+ DenseElementsAttr::get(maskType, true));
}
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
@@ -776,10 +777,10 @@ struct ExtractFromCreateMaskToPselLowering
// Create the two 1-D masks at the location of the 2-D create_mask (which is
// usually outside a loop). This prevents the need for later hoisting.
rewriter.setInsertionPoint(createMaskOp);
- auto rowMask = rewriter.create<vector::CreateMaskOp>(
- loc, rowMaskType, createMaskOp.getOperand(0));
- auto colMask = rewriter.create<vector::CreateMaskOp>(
- loc, colMaskType, createMaskOp.getOperand(1));
+ auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
+ createMaskOp.getOperand(0));
+ auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
+ createMaskOp.getOperand(1));
rewriter.setInsertionPoint(extractOp);
auto position =
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 9a8eb72d72925..77aab85483a8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -412,22 +412,22 @@ struct PrepareContractToGPUMMA
if (maps == infer({{m, k}, {k, n}, {m, n}}))
return rewriter.notifyMatchFailure(op, "contraction already prepared");
if (maps == infer({{m, k}, {n, k}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
std::swap(lhs, rhs);
} else {
@@ -494,13 +494,13 @@ struct CombineTransferReadOpTranspose final
// Fuse through the integer extend op.
if (extOp) {
if (isa<arith::ExtSIOp>(extOp))
- result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+ result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
.getResult();
else if (isa<arith::ExtUIOp>(extOp))
- result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
+ result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
.getResult();
else
- result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
+ result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
.getResult();
}
@@ -579,8 +579,8 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
}
gpu::MMAMatrixType type =
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
- Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
- op.getLoc(), type, op.getBase(), op.getIndices(),
+ Value load = gpu::SubgroupMmaLoadMatrixOp::create(
+ rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
rewriter.getIndexAttr(*stride),
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
@@ -610,8 +610,8 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
}
Value matrix = it->second;
- auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
- op.getLoc(), matrix, op.getBase(), op.getIndices(),
+ auto store = gpu::SubgroupMmaStoreMatrixOp::create(
+ rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
@@ -661,8 +661,8 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
return rewriter.notifyMatchFailure(op, "not a splat");
}
- Value result = rewriter.create<arith::ConstantOp>(
- op.getLoc(), vectorType,
+ Value result = arith::ConstantOp::create(
+ rewriter, op.getLoc(), vectorType,
DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
valueMapping[op.getResult()] = result;
return success();
@@ -743,7 +743,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
}
// Adjust the load offset.
- auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
@@ -757,8 +757,9 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
indices);
- nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
- loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
+ nvgpu::LdMatrixOp newOp =
+ nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
+ indices, *transpose, params->numTiles);
valueMapping[op] = newOp->getResult(0);
return success();
}
@@ -782,17 +783,17 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
"conversion to distributed non-ldmatrix compatible load");
}
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
// This is the individual element type.
Type loadedElType = regInfo->registerLLVMType;
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value fill = rewriter.create<arith::ConstantOp>(
- op.getLoc(), vectorType.getElementType(),
+ Value fill = arith::ConstantOp::create(
+ rewriter, op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
- rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
+ vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
@@ -809,16 +810,16 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
if (failed(coords))
return rewriter.notifyMatchFailure(op, "no coords");
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
- op.getBase(), newIndices);
- result = rewriter.create<vector::InsertOp>(loc, el, result, i);
+ Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
+ op.getBase(), newIndices);
+ result = vector::InsertOp::create(rewriter, loc, el, result, i);
}
} else {
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -828,8 +829,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
innerIdx++) {
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
rewriter, op.getLoc(), *warpMatrixInfo);
@@ -839,10 +840,10 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
- op.getBase(), newIndices);
- result = rewriter.create<vector::InsertOp>(
- op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
+ Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
+ op.getBase(), newIndices);
+ result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
+ ArrayRef<int64_t>{i, innerIdx});
}
}
}
@@ -916,11 +917,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
rewriter, op.getLoc(), *warpMatrixInfo);
@@ -928,11 +929,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return rewriter.notifyMatchFailure(op, "no coords");
Value el =
- rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+ vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferWriteOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
+ vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
@@ -1015,8 +1016,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
else if (offsets[1])
sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
- Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, sourceVector, sliceOffset, sliceShape, strides);
+ Value newOp = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
valueMapping[op] = newOp;
return success();
@@ -1035,9 +1036,10 @@ convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
itC == valueMapping.end())
return rewriter.notifyMatchFailure(op, "no mapping");
Value opA = itA->second, opB = itB->second, opC = itC->second;
- Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
- op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
- /*b_transpose=*/UnitAttr());
+ Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
+ opC.getType(), opA, opB, opC,
+ /*a_transpose=*/UnitAttr(),
+ /*b_transpose=*/UnitAttr());
valueMapping[op.getResult()] = matmul;
return success();
}
@@ -1058,8 +1060,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
- Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
- op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
+ Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
+ rewriter.getI64ArrayAttr({m, n, k}));
valueMapping[op.getResult()] = matmul;
return success();
}
@@ -1076,13 +1078,13 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
auto splat =
cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
auto scalarConstant =
- rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
+ arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
auto vecType = cast<VectorType>(op.getType());
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
- op.getLoc(), type, scalarConstant);
+ auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
+ type, scalarConstant);
valueMapping[op.getResult()] = matrix;
return success();
}
@@ -1100,8 +1102,8 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
auto vecType = op.getResultVectorType();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
- op.getLoc(), type, op.getSource());
+ auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
+ type, op.getSource());
valueMapping[op.getResult()] = matrix;
return success();
}
@@ -1118,9 +1120,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getInitArgs());
llvm::append_range(operands, newInitArgs);
- scf::ForOp newLoop = rewriter.create<scf::ForOp>(
- loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
- operands);
+ scf::ForOp newLoop =
+ scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
+ loop.getUpperBound(), loop.getStep(), operands);
rewriter.eraseBlock(newLoop.getBody());
newLoop.getRegion().getBlocks().splice(
@@ -1189,7 +1191,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
yieldOperands.push_back(it->second);
}
- rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+ scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
rewriter.eraseOp(op);
@@ -1220,8 +1222,8 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
resultType.getOperand());
}
- Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
- op->getLoc(), resultType, matrixOperands, opType);
+ Value newOp = gpu::SubgroupMmaElementwiseOp::create(
+ rewriter, op->getLoc(), resultType, matrixOperands, opType);
valueMapping[op->getResult(0)] = newOp;
return success();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e4ff770a807c6..9cd491caa9421 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -43,13 +43,13 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
auto idxType = rewriter.getIndexType();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
+ auto constant = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
- return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
- constant);
+ return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
+ constant);
}
- return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
+ return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
}
// Helper that picks the proper sequence for extracting.
@@ -58,13 +58,13 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank <= 1) {
auto idxType = rewriter.getIndexType();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
+ auto constant = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
- return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
- constant);
+ return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
+ constant);
}
- return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
+ return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
}
// Helper that returns data layout alignment of a vector.
@@ -141,9 +141,9 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
auto ptrsType =
LLVM::getVectorType(pType, vectorType.getDimSize(0),
/*isScalable=*/vectorType.getScalableDims()[0]);
- return rewriter.create<LLVM::GEPOp>(
- loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
- base, index);
+ return LLVM::GEPOp::create(
+ rewriter, loc, ptrsType,
+ typeConverter.convertType(memRefType.getElementType()), base, index);
}
/// Convert `foldResult` into a Value. Integer attribute is converted to
@@ -152,7 +152,7 @@ static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
if (auto attr = dyn_cast<Attribute>(foldResult)) {
auto intAttr = cast<IntegerAttr>(attr);
- return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
+ return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
}
return cast<Value>(foldResult);
@@ -440,32 +440,32 @@ class ReductionNeutralFPMax {};
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
- rewriter.getZeroAttr(llvmType));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getZeroAttr(llvmType));
}
/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, 1));
}
/// Create the reduction neutral fp one value.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getFloatAttr(llvmType, 1.0));
}
/// Create the reduction neutral all-ones value.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(
llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}
@@ -474,8 +474,8 @@ static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -484,8 +484,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -494,8 +494,8 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -504,8 +504,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -515,8 +515,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = cast<FloatType>(llvmType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/false)));
@@ -527,8 +527,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = cast<FloatType>(llvmType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/true)));
@@ -556,19 +556,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
- Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value baseVecLength = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
if (!vType.getScalableDims()[0])
return baseVecLength;
// For a scalable vector type, create and return `vScale * baseVecLength`.
- Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+ Value vScale = vector::VectorScaleOp::create(rewriter, loc);
vScale =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
Value scalableVecLength =
- rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
+ arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
return scalableVecLength;
}
@@ -581,10 +581,11 @@ static Value createIntegerReductionArithmeticOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
- Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+ Value result =
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
if (accumulator)
- result = rewriter.create<ScalarOp>(loc, accumulator, result);
+ result = ScalarOp::create(rewriter, loc, accumulator, result);
return result;
}
@@ -596,11 +597,12 @@ template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
- Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+ Value result =
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
if (accumulator) {
Value cmp =
- rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
- result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
+ LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
+ result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
}
return result;
}
@@ -631,12 +633,11 @@ static Value createFPReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
Value result =
- rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
if (accumulator) {
- result =
- rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
- loc, result, accumulator);
+ result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
+ rewriter, loc, result, accumulator);
}
return result;
@@ -667,7 +668,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
- return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
+ return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
}
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
@@ -682,8 +683,8 @@ lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
Value mask, LLVM::FastmathFlagsAttr fmf) {
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
rewriter, loc, llvmType, vectorOperand.getType());
- const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
- loc, mask, vectorOperand, vectorMaskNeutral);
+ const Value selectedVectorByMask = LLVM::SelectOp::create(
+ rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
@@ -695,9 +696,9 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
- return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand, fmf);
+ return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand,
+ fmf);
}
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
@@ -710,9 +711,8 @@ lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
Value vectorOperand, Value accumulator) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
- return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand);
+ return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
@@ -723,9 +723,9 @@ static Value lowerPredicatedReductionWithStartValue(
llvmType, accumulator);
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
- return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand, mask, vectorLength);
+ return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand,
+ mask, vectorLength);
}
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
@@ -1036,8 +1036,8 @@ class VectorShuffleOpConversion
// For rank 0 and 1, where both operands have *exactly* the same vector
// type, there is direct shuffle support in LLVM. Use it!
if (rank <= 1 && v1Type == v2Type) {
- Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getV1(), adaptor.getV2(),
+ Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getV1(), adaptor.getV2(),
llvm::to_vector_of<int32_t>(mask));
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
@@ -1050,7 +1050,7 @@ class VectorShuffleOpConversion
eltType = arrayType.getElementType();
else
eltType = cast<VectorType>(llvmType).getElementType();
- Value insert = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
+ Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
int64_t insPos = 0;
for (int64_t extPos : mask) {
Value value = adaptor.getV1();
@@ -1087,9 +1087,9 @@ class VectorExtractElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = extractEltOp.getLoc();
auto idxType = rewriter.getIndexType();
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ auto zero = LLVM::ConstantOp::create(rewriter, loc,
+ typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
@@ -1158,13 +1158,14 @@ class VectorExtractOpConversion
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
}
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, getAsIntegers(position));
+ extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
+ getAsIntegers(position));
}
if (extractsScalar) {
- extracted = rewriter.create<LLVM::ExtractElementOp>(
- loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
+ extracted = LLVM::ExtractElementOp::create(
+ rewriter, loc, extracted,
+ getAsLLVMValue(rewriter, loc, positionVec.back()));
}
rewriter.replaceOp(extractOp, extracted);
@@ -1221,9 +1222,9 @@ class VectorInsertElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = insertEltOp.getLoc();
auto idxType = rewriter.getIndexType();
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ auto zero = LLVM::ConstantOp::create(rewriter, loc,
+ typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
@@ -1307,8 +1308,8 @@ class VectorInsertOpConversion
// llvm.extractvalue does not support dynamic dimensions.
return failure();
}
- sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getDest(),
+ sourceAggregate = LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getDest(),
getAsIntegers(positionOf1DVectorWithinAggregate));
} else {
// No-aggregate case. The destination for the InsertElementOp is just
@@ -1316,16 +1317,16 @@ class VectorInsertOpConversion
sourceAggregate = adaptor.getDest();
}
// Insert the scalar into the 1D vector.
- sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
- loc, sourceAggregate.getType(), sourceAggregate,
+ sourceAggregate = LLVM::InsertElementOp::create(
+ rewriter, loc, sourceAggregate.getType(), sourceAggregate,
adaptor.getValueToStore(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
}
Value result = sourceAggregate;
if (isNestedAggregate) {
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), sourceAggregate,
+ result = LLVM::InsertValueOp::create(
+ rewriter, loc, adaptor.getDest(), sourceAggregate,
getAsIntegers(positionOf1DVectorWithinAggregate));
}
@@ -1404,15 +1405,15 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto loc = op.getLoc();
auto elemType = vType.getElementType();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
- Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
- Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
- Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
- Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
- desc = rewriter.create<InsertOp>(loc, fma, desc, i);
+ Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
+ Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
+ Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
+ Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
+ desc = InsertOp::create(rewriter, loc, fma, desc, i);
}
rewriter.replaceOp(op, desc);
return success();
@@ -1502,7 +1503,7 @@ class VectorTypeCastOpConversion
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
- auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+ auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
desc.setOffset(rewriter, loc, zero);
// Fill size and stride descriptors in memref.
@@ -1511,11 +1512,12 @@ class VectorTypeCastOpConversion
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
- auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+ auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
desc.setSize(rewriter, loc, index, size);
auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
(*targetStrides)[index]);
- auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+ auto stride =
+ LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
desc.setStride(rewriter, loc, index, stride);
}
@@ -1543,14 +1545,15 @@ class VectorCreateMaskOpConversion
IntegerType idxType =
force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
auto loc = op->getLoc();
- Value indices = rewriter.create<LLVM::StepVectorOp>(
- loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
- /*isScalable=*/true));
+ Value indices = LLVM::StepVectorOp::create(
+ rewriter, loc,
+ LLVM::getVectorType(idxType, dstType.getShape()[0],
+ /*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
- Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
- Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- indices, bounds);
+ Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
+ Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ indices, bounds);
rewriter.replaceOp(op, comp);
return success();
}
@@ -1706,16 +1709,16 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
- value = rewriter.create<arith::ExtUIOp>(
- loc, IntegerType::get(rewriter.getContext(), 64), value);
+ value = arith::ExtUIOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
- value = rewriter.create<arith::ExtSIOp>(
- loc, IntegerType::get(rewriter.getContext(), 64), value);
+ value = arith::ExtSIOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
- value = rewriter.create<LLVM::BitcastOp>(
- loc, IntegerType::get(rewriter.getContext(), 16), value);
+ value = LLVM::BitcastOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
@@ -1727,8 +1730,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Helper to emit a call.
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
Operation *ref, ValueRange params = ValueRange()) {
- rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
- params);
+ LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
+ params);
}
};
@@ -1754,9 +1757,9 @@ struct VectorBroadcastScalarToLowRankLowering
// First insert it into a poison vector so we can shuffle it.
auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
- rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
- auto zero = rewriter.create<LLVM::ConstantOp>(
- broadcast.getLoc(),
+ LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
+ auto zero = LLVM::ConstantOp::create(
+ rewriter, broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
@@ -1768,8 +1771,9 @@ struct VectorBroadcastScalarToLowRankLowering
}
// For 1-d vector, we additionally do a `vectorshuffle`.
- auto v = rewriter.create<LLVM::InsertElementOp>(
- broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
+ auto v =
+ LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
+ poison, adaptor.getSource(), zero);
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
@@ -1811,26 +1815,26 @@ struct VectorBroadcastScalarToNdLowering
return failure();
// Construct returned value.
- Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
+ Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
// Construct a 1-D vector with the broadcasted value that we insert in all
// the places within the returned descriptor.
- Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+ Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
+ auto zero = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
- Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.getSource(), zero);
+ Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
+ vdesc, adaptor.getSource(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
- v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
+ v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
// Iterate of linear index, convert to coords space and insert broadcasted
// 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
+ desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
});
rewriter.replaceOp(broadcast, desc);
return success();
@@ -1900,13 +1904,13 @@ struct VectorDeinterleaveOpLowering
auto deinterleaveResults = deinterleaveOp.getResultTypes();
auto packedOpResults =
llvmTypeConverter->packOperationResults(deinterleaveResults);
- auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
- loc, packedOpResults, adaptor.getSource());
+ auto intrinsic = LLVM::vector_deinterleave2::create(
+ rewriter, loc, packedOpResults, adaptor.getSource());
- auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
- loc, intrinsic->getResult(0), 0);
- auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
- loc, intrinsic->getResult(0), 1);
+ auto evenResult = LLVM::ExtractValueOp::create(
+ rewriter, loc, intrinsic->getResult(0), 0);
+ auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
+ intrinsic->getResult(0), 1);
rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
return success();
@@ -1929,11 +1933,11 @@ struct VectorDeinterleaveOpLowering
oddShuffleMask.push_back(i);
}
- auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
- auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getSource(), poison, evenShuffleMask);
- auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getSource(), poison, oddShuffleMask);
+ auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
+ auto evenShuffle = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
+ auto oddShuffle = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
return success();
@@ -1956,9 +1960,9 @@ struct VectorFromElementsLowering
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
+ result = vector::InsertOp::create(rewriter, loc, val, result, idx);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
@@ -1982,12 +1986,12 @@ struct VectorToElementsLowering
if (element.use_empty())
continue;
- auto constIdx = rewriter.create<LLVM::ConstantOp>(
- loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+ auto constIdx = LLVM::ConstantOp::create(
+ rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
auto llvmType = typeConverter->convertType(element.getType());
- Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
- source, constIdx);
+ Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
+ source, constIdx);
results[idx] = result;
}
@@ -2098,7 +2102,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Value lhs = op.getLhs();
auto lhsMap = op.getIndexingMapsArray()[0];
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
- lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+ lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
return failure();
@@ -2106,7 +2110,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Value rhs = op.getRhs();
auto rhsMap = op.getIndexingMapsArray()[1];
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
- rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+ rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
return failure();
@@ -2119,20 +2123,20 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Type flattenedLHSType =
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
- lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+ lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
Type flattenedRHSType =
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
- rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+ rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
- Value mul = rew.create<LLVM::MatrixMultiplyOp>(
- loc,
+ Value mul = LLVM::MatrixMultiplyOp::create(
+ rew, loc,
VectorType::get(lhsRows * rhsColumns,
cast<VectorType>(lhs.getType()).getElementType()),
lhs, rhs, lhsRows, lhsColumns, rhsColumns);
- mul = rew.create<vector::ShapeCastOp>(
- loc,
+ mul = vector::ShapeCastOp::create(
+ rew, loc,
VectorType::get({lhsRows, rhsColumns},
getElementTypeOrSelf(op.getAcc().getType())),
mul);
@@ -2140,15 +2144,15 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
// ACC must be C(m, n) or C(n, m).
auto accMap = op.getIndexingMapsArray()[2];
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
- mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+ mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
llvm_unreachable("invalid contraction semantics");
- Value res =
- isa<IntegerType>(elementType)
- ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
- : static_cast<Value>(
- rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+ Value res = isa<IntegerType>(elementType)
+ ? static_cast<Value>(
+ arith::AddIOp::create(rew, loc, op.getAcc(), mul))
+ : static_cast<Value>(
+ arith::AddFOp::create(rew, loc, op.getAcc(), mul));
return res;
}
@@ -2181,11 +2185,11 @@ class TransposeOpToMatrixTransposeOpLowering
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
- Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
- loc, flattenedType, matrix, rows, columns);
+ Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
+ matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 43732f58a4e0a..4c1047a8871a5 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -132,9 +132,9 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
Value value) {
if (hasRetVal) {
assert(value && "Expected non-empty value");
- b.create<scf::YieldOp>(loc, value);
+ scf::YieldOp::create(b, loc, value);
} else {
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
}
@@ -154,7 +154,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
+ return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -201,22 +201,22 @@ static Value generateInBoundsCheck(
Value base = xferOp.getIndices()[*dim];
Value memrefIdx =
affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
- cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
- memrefIdx);
+ cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
+ memrefIdx);
}
// Condition check 2: Masked in?
if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
if (cond)
- cond = lb.create<arith::AndIOp>(cond, maskCond);
+ cond = arith::AndIOp::create(lb, cond, maskCond);
else
cond = maskCond;
}
// If the condition is non-empty, generate an SCF::IfOp.
if (cond) {
- auto check = lb.create<scf::IfOp>(
- cond,
+ auto check = scf::IfOp::create(
+ lb, cond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
@@ -226,7 +226,7 @@ static Value generateInBoundsCheck(
if (outOfBoundsCase) {
maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
} else {
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
});
@@ -303,14 +303,15 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
BufferAllocs result;
auto bufferType = MemRefType::get({}, xferOp.getVectorType());
- result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
+ result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);
if (xferOp.getMask()) {
auto maskType = MemRefType::get({}, xferOp.getMask().getType());
- auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
+ auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
b.setInsertionPoint(xferOp);
- b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
- result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
+ memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
+ result.maskBuffer =
+ memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
}
return result;
@@ -421,14 +422,15 @@ struct Strategy<TransferReadOp> {
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, vecType, xferOp.getBase(), xferIndices,
+ auto newXferOp = vector::TransferReadOp::create(
+ b, loc, vecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeApplyPassLabel(b, newXferOp, options.targetRank);
- b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
+ memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
+ storeIndices);
return newXferOp;
}
@@ -444,8 +446,9 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
- auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
- b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
+ auto vec =
+ vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
+ memref::StoreOp::create(b, loc, vec, buffer, storeIndices);
return Value();
}
@@ -506,12 +509,12 @@ struct Strategy<TransferWriteOp> {
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
- auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
+ auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
- auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, type, vec, source, xferIndices,
+ auto newXferOp = vector::TransferWriteOp::create(
+ b, loc, type, vec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
@@ -610,8 +613,8 @@ struct PrepareTransferReadConversion
}
Location loc = xferOp.getLoc();
- rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
- buffers.dataBuffer);
+ memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
+ buffers.dataBuffer);
rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
return success();
@@ -653,9 +656,9 @@ struct PrepareTransferWriteConversion
Location loc = xferOp.getLoc();
auto buffers = allocBuffers(rewriter, xferOp);
- rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
- buffers.dataBuffer);
- auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
+ memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
+ buffers.dataBuffer);
+ auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getValueToStoreMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
@@ -735,17 +738,17 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
auto signlessTargetVectorType =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
- value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
- value);
+ value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
+ value);
if (value.getType() != signlessTargetVectorType) {
if (width == 1 || intTy.isUnsigned())
- value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
- value);
+ value = arith::ExtUIOp::create(rewriter, loc,
+ signlessTargetVectorType, value);
else
- value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
- value);
+ value = arith::ExtSIOp::create(rewriter, loc,
+ signlessTargetVectorType, value);
}
- value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
+ value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
vectorType = targetVectorType;
}
@@ -762,29 +765,30 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
std::multiplies<int64_t>());
auto flatVectorType =
VectorType::get({flatLength}, vectorType.getElementType());
- value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
+ value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
}
vector::PrintOp firstClose;
SmallVector<Value, 8> loopIndices;
for (unsigned d = 0; d < shape.size(); d++) {
// Setup loop bounds and step.
- Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
- Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value upperBound =
+ arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
+ Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
if (!scalableDimensions.empty() && scalableDimensions[d]) {
- auto vscale = rewriter.create<vector::VectorScaleOp>(
- loc, rewriter.getIndexType());
- upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc,
+ rewriter.getIndexType());
+ upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
}
- auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
+ auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);
// Create a loop to print the elements surrounded by parentheses.
- rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
auto loop =
- rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- auto printClose = rewriter.create<vector::PrintOp>(
- loc, vector::PrintPunctuation::Close);
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
+ auto printClose = vector::PrintOp::create(
+ rewriter, loc, vector::PrintPunctuation::Close);
if (!firstClose)
firstClose = printClose;
@@ -793,14 +797,14 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
// Print a comma after all but the last element.
rewriter.setInsertionPointToStart(loop.getBody());
- auto notLastIndex = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
- rewriter.create<scf::IfOp>(loc, notLastIndex,
- [&](OpBuilder &builder, Location loc) {
- builder.create<vector::PrintOp>(
- loc, vector::PrintPunctuation::Comma);
- builder.create<scf::YieldOp>(loc);
- });
+ auto notLastIndex = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
+ scf::IfOp::create(rewriter, loc, notLastIndex,
+ [&](OpBuilder &builder, Location loc) {
+ vector::PrintOp::create(
+ builder, loc, vector::PrintPunctuation::Comma);
+ scf::YieldOp::create(builder, loc);
+ });
rewriter.setInsertionPointToStart(loop.getBody());
}
@@ -810,22 +814,23 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
Value flatIndex;
auto currentStride = 1;
for (int d = shape.size() - 1; d >= 0; d--) {
- auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
- auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
+ auto stride =
+ arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+ auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
if (flatIndex)
- flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
+ flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
else
flatIndex = index;
currentStride *= shape[d];
}
// Print the scalar elements in the inner most loop.
- auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
- rewriter.create<vector::PrintOp>(loc, element,
- vector::PrintPunctuation::NoPunctuation);
+ auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
+ vector::PrintOp::create(rewriter, loc, element,
+ vector::PrintPunctuation::NoPunctuation);
rewriter.setInsertionPointAfter(firstClose);
- rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
+ vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
rewriter.eraseOp(printOp);
return success();
}
@@ -916,7 +921,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
"Failed to unpack one vector dim.");
auto castedDataBuffer =
- locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
+ vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
@@ -935,22 +940,22 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
- locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
+ vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
}
}
// Loop bounds and step.
- auto lb = locB.create<arith::ConstantIndexOp>(0);
- auto ub = locB.create<arith::ConstantIndexOp>(
- castedDataType->getDimSize(castedDataType->getRank() - 1));
- auto step = locB.create<arith::ConstantIndexOp>(1);
+ auto lb = arith::ConstantIndexOp::create(locB, 0);
+ auto ub = arith::ConstantIndexOp::create(
+ locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
+ auto step = arith::ConstantIndexOp::create(locB, 1);
// TransferWriteOps that operate on tensors return the modified tensor and
// require a loop state.
auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
// Generate for loop.
- auto result = locB.create<scf::ForOp>(
- lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+ auto result = scf::ForOp::create(
+ locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
Type stateType = loopState.empty() ? Type() : loopState[0].getType();
@@ -975,8 +980,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
SmallVector<Value, 8> loadIndices;
getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
loadIndices, iv);
- auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
- loadIndices);
+ auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
+ loadIndices);
rewriter.modifyOpInPlace(newXfer, [&]() {
newXfer.getMaskMutable().assign(mask);
});
@@ -1119,30 +1124,30 @@ struct ScalableTransposeTransferWriteConversion
auto transposeSource = transposeOp.getVector();
SmallVector<Value> transposeSourceSlices =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
- return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+ return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
});
// Loop bounds and step.
- auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto ub =
maskDims->empty()
? Value(createVscaleMultiple(vectorType.getDimSize(0)))
: vector::getAsValues(rewriter, loc, maskDims->front()).front();
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Generate a new mask for the slice.
VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
Value sliceMask = nullptr;
if (!maskDims->empty()) {
- sliceMask = rewriter.create<vector::CreateMaskOp>(
- loc, sliceType.clone(rewriter.getI1Type()),
+ sliceMask = vector::CreateMaskOp::create(
+ rewriter, loc, sliceType.clone(rewriter.getI1Type()),
ArrayRef<OpFoldResult>(*maskDims).drop_front());
}
Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
- auto result = rewriter.create<scf::ForOp>(
- loc, lb, ub, step, initLoopArgs,
+ auto result = scf::ForOp::create(
+ rewriter, loc, lb, ub, step, initLoopArgs,
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
// Indices for the new transfer op.
SmallVector<Value, 8> xferIndices;
@@ -1151,25 +1156,25 @@ struct ScalableTransposeTransferWriteConversion
// Extract a transposed slice from the source vector.
SmallVector<Value> transposeElements =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
- return b.create<vector::ExtractOp>(
- loc, transposeSourceSlices[idx], iv);
+ return vector::ExtractOp::create(
+ b, loc, transposeSourceSlices[idx], iv);
});
- auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
- transposeElements);
+ auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
+ transposeElements);
// Create the transfer_write for the slice.
Value dest =
loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
- auto newWriteOp = b.create<vector::TransferWriteOp>(
- loc, sliceVec, dest, xferIndices,
+ auto newWriteOp = vector::TransferWriteOp::create(
+ b, loc, sliceVec, dest, xferIndices,
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
if (sliceMask)
newWriteOp.getMaskMutable().assign(sliceMask);
// Yield from the loop.
- b.create<scf::YieldOp>(loc, loopIterArgs.empty()
- ? ValueRange{}
- : newWriteOp.getResult());
+ scf::YieldOp::create(b, loc,
+ loopIterArgs.empty() ? ValueRange{}
+ : newWriteOp.getResult());
});
if (isTensorOp(writeOp))
@@ -1207,7 +1212,7 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
llvm::SmallVector<int64_t, 1> indices({i});
Location loc = xferOp.getLoc();
- auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
+ auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
newXferOp.getMaskMutable().assign(newMask);
}
@@ -1261,8 +1266,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
- return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1317,7 +1322,7 @@ struct UnrollTransferReadConversion
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
- Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
// FIXME: Rename this lambda - it does much more than just
// in-bounds-check generation.
@@ -1336,8 +1341,8 @@ struct UnrollTransferReadConversion
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.getBase(), xferIndices,
+ auto newXferOp = vector::TransferReadOp::create(
+ b, loc, newXferVecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
@@ -1346,11 +1351,11 @@ struct UnrollTransferReadConversion
if (newXferVecType.getRank() == 0) {
// vector.insert does not accept rank-0 as the non-indexed
// argument. Extract the scalar before inserting.
- valToInser = b.create<vector::ExtractOp>(loc, valToInser,
- SmallVector<int64_t>());
+ valToInser = vector::ExtractOp::create(b, loc, valToInser,
+ SmallVector<int64_t>());
}
- return b.create<vector::InsertOp>(loc, valToInser, vec,
- insertionIndices);
+ return vector::InsertOp::create(b, loc, valToInser, vec,
+ insertionIndices);
},
/*outOfBoundsCase=*/
[&](OpBuilder &b, Location loc) {
@@ -1460,7 +1465,7 @@ struct UnrollTransferWriteConversion
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
- Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
auto updatedSource = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp),
@@ -1477,20 +1482,20 @@ struct UnrollTransferWriteConversion
extractionIndices.push_back(b.getI64IntegerAttr(i));
auto extracted =
- b.create<vector::ExtractOp>(loc, vec, extractionIndices);
+ vector::ExtractOp::create(b, loc, vec, extractionIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
Value xferVec;
if (inputVectorTy.getRank() == 1) {
// When target-rank=0, unrolling would causes the vector input
// argument into `transfer_write` to become a scalar. We solve
// this by broadcasting the scalar to a 0D vector.
- xferVec = b.create<vector::BroadcastOp>(
- loc, VectorType::get({}, extracted.getType()), extracted);
+ xferVec = vector::BroadcastOp::create(
+ b, loc, VectorType::get({}, extracted.getType()), extracted);
} else {
xferVec = extracted;
}
- auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, sourceType, xferVec, source, xferIndices,
+ auto newXferOp = vector::TransferWriteOp::create(
+ b, loc, sourceType, xferVec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
@@ -1572,19 +1577,19 @@ struct Strategy1d<TransferReadOp> {
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
- Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
- return b.create<vector::InsertOp>(loc, val, vec, iv);
+ Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
+ return vector::InsertOp::create(b, loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
- b.create<scf::YieldOp>(loc, nextVec);
+ scf::YieldOp::create(b, loc, nextVec);
}
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Inititalize vector with padding value.
Location loc = xferOp.getLoc();
- return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
};
@@ -1601,10 +1606,10 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
- b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
+ auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
+ memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
});
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
@@ -1665,15 +1670,15 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
// Loop bounds, step, state...
Location loc = xferOp.getLoc();
auto vecType = xferOp.getVectorType();
- auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value ub =
- rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
+ arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
if (vecType.isScalable()) {
Value vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
}
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
// Generate for loop.
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 750ce85049409..00ee3faa908e1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -161,19 +161,19 @@ static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
Location loc, Value dynamicIndex,
int64_t kPoisonIndex, unsigned vectorSize) {
if (llvm::isPowerOf2_32(vectorSize)) {
- Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
- loc, dynamicIndex.getType(),
+ Value inBoundsMask = spirv::ConstantOp::create(
+ rewriter, loc, dynamicIndex.getType(),
rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
- return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
- inBoundsMask);
+ return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
+ inBoundsMask);
}
- Value poisonIndex = rewriter.create<spirv::ConstantOp>(
- loc, dynamicIndex.getType(),
+ Value poisonIndex = spirv::ConstantOp::create(
+ rewriter, loc, dynamicIndex.getType(),
rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
Value cmpResult =
- rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
- return rewriter.create<spirv::SelectOp>(
- loc, cmpResult,
+ spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
+ return spirv::SelectOp::create(
+ rewriter, loc, cmpResult,
spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
dynamicIndex);
}
@@ -441,8 +441,8 @@ static SmallVector<Value> extractAllElements(
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
- values.push_back(rewriter.create<spirv::CompositeExtractOp>(
- loc, srcVectorType.getElementType(), adaptor.getVector(),
+ values.push_back(spirv::CompositeExtractOp::create(
+ rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
rewriter.getI32ArrayAttr({i})));
}
if (Value acc = adaptor.getAcc())
@@ -495,16 +495,16 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
case vector::CombiningKind::kind: \
if (llvm::isa<IntegerType>(resultType)) { \
- result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
+ result = spirv::iop::create(rewriter, loc, resultType, result, next); \
} else { \
assert(llvm::isa<FloatType>(resultType)); \
- result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
+ result = spirv::fop::create(rewriter, loc, resultType, result, next); \
} \
break
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
+ result = fop::create(rewriter, loc, resultType, result, next); \
break
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
@@ -551,7 +551,7 @@ struct VectorReductionFloatMinMax final
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
+ result = fop::create(rewriter, loc, resultType, result, next); \
break
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
@@ -632,8 +632,8 @@ struct VectorShuffleOpConvert final
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
- return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
- idx);
+ return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
+ idx);
assert(idx == 0 && "Invalid scalar element index");
return scalarOrVec;
@@ -731,11 +731,13 @@ struct VectorDeinterleaveOpConvert final
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeExtractOp`.
if (n == 2) {
- auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
+ auto elem0 = spirv::CompositeExtractOp::create(
+ rewriter, loc, newResultType, sourceVector,
+ rewriter.getI32ArrayAttr({0}));
- auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
+ auto elem1 = spirv::CompositeExtractOp::create(
+ rewriter, loc, newResultType, sourceVector,
+ rewriter.getI32ArrayAttr({1}));
rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
return success();
@@ -752,12 +754,12 @@ struct VectorDeinterleaveOpConvert final
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
// Create two SPIR-V shuffles.
- auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, sourceVector, sourceVector,
+ auto shuffleEven = spirv::VectorShuffleOp::create(
+ rewriter, loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesEven));
- auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, sourceVector, sourceVector,
+ auto shuffleOdd = spirv::VectorShuffleOp::create(
+ rewriter, loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesOdd));
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
@@ -798,10 +800,11 @@ struct VectorLoadOpConverter final
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
- Value castedAccessChain = (vectorType.getNumElements() == 1)
- ? accessChain
- : rewriter.create<spirv::BitcastOp>(
- loc, vectorPtrType, accessChain);
+ Value castedAccessChain =
+ (vectorType.getNumElements() == 1)
+ ? accessChain
+ : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
+ accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
@@ -840,10 +843,11 @@ struct VectorStoreOpConverter final
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
- Value castedAccessChain = (vectorType.getNumElements() == 1)
- ? accessChain
- : rewriter.create<spirv::BitcastOp>(
- loc, vectorPtrType, accessChain);
+ Value castedAccessChain =
+ (vectorType.getNumElements() == 1)
+ ? accessChain
+ : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
+ accessChain);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
@@ -924,10 +928,10 @@ struct VectorReductionToIntDotProd final
auto v4i8Type = VectorType::get({4}, i8Type);
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
- lhsIn = rewriter.create<spirv::CompositeConstructOp>(
- loc, v4i8Type, ValueRange{lhsIn, zero});
- rhsIn = rewriter.create<spirv::CompositeConstructOp>(
- loc, v4i8Type, ValueRange{rhsIn, zero});
+ lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
+ ValueRange{lhsIn, zero});
+ rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
+ ValueRange{rhsIn, zero});
}
// There's no variant of dot prod ops for unsigned LHS and signed RHS, so
@@ -990,14 +994,14 @@ struct VectorReductionToFPDotProd final
Attribute oneAttr =
rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
- rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
+ rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
}
assert(lhs);
assert(rhs);
- Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
+ Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
if (acc)
- res = rewriter.create<spirv::FAddOp>(loc, acc, res);
+ res = spirv::FAddOp::create(rewriter, loc, acc, res);
rewriter.replaceOp(op, res);
return success();
@@ -1032,7 +1036,8 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
source.reserve(numElements);
for (int64_t i = 0; i < numElements; ++i) {
Attribute intAttr = rewriter.getIntegerAttr(intType, i);
- Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ Value constOp =
+ spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
source.push_back(constOp);
}
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
@@ -1075,8 +1080,8 @@ struct VectorToElementOpConvert final
if (element.use_empty())
continue;
- Value result = rewriter.create<spirv::CompositeExtractOp>(
- loc, elementType, adaptor.getSource(),
+ Value result = spirv::CompositeExtractOp::create(
+ rewriter, loc, elementType, adaptor.getSource(),
rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
results[idx] = result;
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2e6a16ddbfdaa..80107554144cf 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -108,15 +108,15 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
- ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
- getAsOpFoldResult(offsets));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ getAsOpFoldResult(offsets));
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
unsigned srcRank = srcTy.getRank();
for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
+ sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
SmallVector<int64_t> constOffsets;
SmallVector<Value> dynOffsets;
@@ -135,18 +135,18 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
// Compute strides in reverse order.
SmallVector<Value> dynStrides;
- Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
- rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
+ arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
dynStrides.push_back(accStride);
}
std::reverse(dynStrides.begin(), dynStrides.end());
- ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
- loc, descType, src, dynOffsets, dynShapes, dynStrides,
+ ndDesc = xegpu::CreateNdDescOp::create(
+ rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
DenseI64ArrayAttr::get(rewriter.getContext(), strides));
@@ -200,10 +200,10 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = rewriter.create<xegpu::LoadNdOp>(
- loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -238,9 +238,9 @@ struct TransferWriteLowering
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
- rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -269,8 +269,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
- loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+ auto loadNdOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
@@ -303,9 +303,9 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeNdOp =
- rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
return success();
@@ -339,8 +339,9 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
- auto dpasOp = rewriter.create<xegpu::DpasOp>(
- loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
+ auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
+ TypeRange{contractOp.getResultType()},
+ ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
return success();
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index a8380b9669f0f..2411af043f3f7 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -251,7 +251,7 @@ static LLVM::CallOp createDeviceFunctionCall(
for (auto [idx, attrName] : paramAttrs)
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
- auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
callOp->setAttrs(funcOp->getAttrs());
return callOp;
@@ -299,7 +299,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
VectorType newTy = VectorType::get(
vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
if (origTy != newTy)
- val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+ val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
return val;
};
@@ -326,7 +326,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
: cOrigTy;
VectorType resTy = cTy;
if (cOrigTy != cTy)
- c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+ c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
constexpr int32_t systolicDepth{8};
std::string fnName =
@@ -352,7 +352,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
->getResult(0);
if (resOrigTy != resTy)
- result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
rewriter.replaceOp(op, result);
return success();
@@ -383,7 +383,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
auto loc = op.getLoc();
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
Value one =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
SmallVector<Value> args{op.getPtr(), one};
SmallVector<Type> argTypes;
for (auto arg : args)
@@ -439,11 +439,11 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
op, "Fence only supports workgroup and device memory scopes.");
}
Type i32Type = rewriter.getI32Type();
- Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
+ Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
Value memScopeConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
Value addrSpaceConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
SmallVector<Type> argTypes{3, i32Type};
createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
@@ -477,13 +477,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
auto i32Type = rewriter.getI32Type();
Value byteCoord =
- rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
+ LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
SmallVector<Type> retTypes;
@@ -504,11 +504,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
} else {
auto vecElemType = vecType.getElementType();
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
- Value numElems = rewriter.create<LLVM::ConstantOp>(
- loc, i32Type, vecType.getNumElements());
- auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
- numElems);
+ Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
+ vecType.getNumElements());
+ auto dstOrSrcPtr = LLVM::AllocaOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
+ vecElemType, numElems);
args.push_back(dstOrSrcPtr);
if constexpr (isLoad) { // Load
funcName += "read";
@@ -530,7 +530,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
bitWidthId = (vecElemBitWidth == 32)
? "j"
: ((vecElemBitWidth == 16) ? "t" : "h");
- rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
+ LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
paramAttrs = {
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
@@ -563,7 +563,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
}
if constexpr (isLoad)
rewriter.replaceOp(
- op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
+ op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
else
rewriter.eraseOp(op);
return success();
More information about the Mlir-commits
mailing list