[Mlir-commits] [mlir] 967626b - [mlir][NFC] update `mlir/Dialect` create APIs (14/n) (#149920)
llvmlistbot at llvm.org
Thu Jul 24 11:03:50 PDT 2025
Author: Maksim Levental
Date: 2025-07-24T13:03:47-05:00
New Revision: 967626b842551ecd997c0d10eb68c3015b63a3d7
URL: https://github.com/llvm/llvm-project/commit/967626b842551ecd997c0d10eb68c3015b63a3d7
DIFF: https://github.com/llvm/llvm-project/commit/967626b842551ecd997c0d10eb68c3015b63a3d7.diff
LOG: [mlir][NFC] update `mlir/Dialect` create APIs (14/n) (#149920)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
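The change is mechanical: op construction moves from the builder's
template method to the op's static `create` overload that takes the
builder as its first argument. A minimal sketch of the pattern, using
one op from the diff below for illustration:

    // Before: builder-centric construction.
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    // After: op-centric construction; builds the same operation (NFC).
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);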
Added:
Modified:
mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 37e0d2af55fe1..6d1f64e94df15 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -99,8 +99,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
Type allBitsType = rewriter.getIntegerType(bitwidth);
auto allBitsVecType = VectorType::get({1}, allBitsType);
- Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
- Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
+ Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
+ Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
return scalar;
}
@@ -118,27 +118,27 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
SmallVector<NamedAttribute> loadAttrs;
patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
- Value initialLoad =
- rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
+ Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
+ invariantArgs, loadAttrs);
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterAtomic =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
+ cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
rewriter.setInsertionPointToEnd(loopBlock);
Value prevLoad = loopBlock->getArgument(0);
- Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
+ Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
dataType = operated.getType();
SmallVector<NamedAttribute> cmpswapAttrs;
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
SmallVector<Value> cmpswapArgs = {operated, prevLoad};
cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
- Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
- loc, dataType, cmpswapArgs, cmpswapAttrs);
+ Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
+ cmpswapArgs, cmpswapAttrs);
// We care about exact bitwise equality here, so do some bitcasts.
// These will fold away during lowering to the ROCDL dialect, where
@@ -150,14 +150,15 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
- rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
+ arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
atomicResForCompare =
- rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
+ arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
}
- Value canLeave = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
- rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
- loopBlock, atomicRes);
+ Value canLeave =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ atomicResForCompare, prevLoadForCompare);
+ cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{},
+ loopBlock, atomicRes);
rewriter.eraseOp(atomicOp);
return success();
}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index af8634c692654..f15c63c166e0a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -54,11 +54,11 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
vector::MaskedLoadOp maskedOp,
bool passthru) {
VectorType vectorType = maskedOp.getVectorType();
- Value load = builder.create<vector::LoadOp>(
- loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+ Value load = vector::LoadOp::create(
+ builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
if (passthru)
- load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
- load, maskedOp.getPassThru());
+ load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
+ load, maskedOp.getPassThru());
return load;
}
@@ -108,7 +108,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
SmallVector<OpFoldResult> indices = maskedOp.getIndices();
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
SmallVector<OpFoldResult> strides =
stridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
@@ -122,47 +122,47 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
// delta = bufferSize - linearizedOffset
Value vectorSizeOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
Value linearIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
Value totalSize = getValueOrCreateConstantIndexOp(
rewriter, loc, linearizedInfo.linearizedSize);
- Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+ Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
// 1) check if delta < vectorSize
- Value isOutofBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+ Value isOutofBounds = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
// 2) check if (delta % elements_per_word != 0)
- Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::divideCeil(32, elementBitWidth));
- Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne,
- rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ Value elementsPerWord = arith::ConstantIndexOp::create(
+ rewriter, loc, llvm::divideCeil(32, elementBitWidth));
+ Value isNotWordAligned = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne,
+ arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
// We take the fallback of maskedload default lowering only if it is both
// out-of-bounds and not word aligned. The fallback ensures correct results
// when loading at the boundary of the buffer since buffer load returns
// inconsistent zeros for the whole word when the boundary is crossed.
Value ifCondition =
- rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+ arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*maskedOp.getOperation());
read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
Value readResult = read->getResult(0);
- builder.create<scf::YieldOp>(loc, readResult);
+ scf::YieldOp::create(builder, loc, readResult);
};
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
/*passthru=*/true);
- rewriter.create<scf::YieldOp>(loc, res);
+ scf::YieldOp::create(rewriter, loc, res);
};
auto ifOp =
- rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+ scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
rewriter.replaceOp(maskedOp, ifOp);
@@ -185,13 +185,13 @@ struct FullMaskedLoadToConditionalLoad
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
/*passthru=*/false);
- rewriter.create<scf::YieldOp>(loc, res);
+ scf::YieldOp::create(rewriter, loc, res);
};
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
- rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+ scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
};
- auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
- falseBuilder);
+ auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
+ falseBuilder);
rewriter.replaceOp(loadOp, ifOp);
return success();
}
@@ -210,11 +210,12 @@ struct FullMaskedStoreToConditionalStore
Value cond = maybeCond.value();
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
- rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
- storeOp.getBase(), storeOp.getIndices());
- rewriter.create<scf::YieldOp>(loc);
+ vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+ storeOp.getBase(), storeOp.getIndices());
+ scf::YieldOp::create(rewriter, loc);
};
- auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+ auto ifOp =
+ scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
rewriter.replaceOp(storeOp, ifOp);
return success();
}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
index 195f59d625554..f8bab8289cbc6 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -37,8 +37,8 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
return rewriter.notifyMatchFailure(metadataOp,
"not a fat raw buffer cast");
Location loc = castOp.getLoc();
- auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, castOp.getSource());
+ auto sourceMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, castOp.getSource());
SmallVector<Value> results;
if (metadataOp.getBaseBuffer().use_empty()) {
results.push_back(nullptr);
@@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
if (baseBufferType == castOp.getResult().getType()) {
results.push_back(castOp.getResult());
} else {
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, castOp.getResult(), /*offset=*/0,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0,
/*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
}
}
if (castOp.getResetOffset())
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0));
else
results.push_back(sourceMetadata.getOffset());
llvm::append_range(results, sourceMetadata.getSizes());
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 12b375b373fa9..748ff1edbfeb2 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -76,8 +76,8 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return SmallVector<Value>{
- rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
- rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+ LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
+ LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
@@ -95,7 +95,7 @@ static Value getStride(Location loc, MemRefType mType, Value base,
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
return rewriter
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
memrefDescriptor.stride(rewriter, loc, preLast))
@@ -103,7 +103,7 @@ static Value getStride(Location loc, MemRefType mType, Value base,
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
- return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+ return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
.getResult();
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index f18cec5a14fae..df39544aeaa09 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -202,7 +202,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
void AffineDataCopyGeneration::runOnOperation() {
func::FuncOp f = getOperation();
OpBuilder topBuilder(f.getBody());
- zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+ zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
// Nests that are copy-in's or copy-out's; the root AffineForOps of those
// nests are stored herein.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 5430bdc4ff858..c0d174a04abf9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -58,8 +58,9 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
// Note: basis elements and their products are, definitionally,
// non-negative, so `nuw` is justified.
if (dynamicPart)
- dynamicPart = rewriter.create<arith::MulIOp>(
- loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
+ dynamicPart =
+ arith::MulIOp::create(rewriter, loc, dynamicPart,
+ dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -74,7 +75,7 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
stride =
- rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
+ arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -106,20 +107,20 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
Value initialPart =
- rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+ arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
results.push_back(initialPart);
auto emitModTerm = [&](Value stride) -> Value {
- Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
- Value remainderNegative = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zero);
+ Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
+ Value remainderNegative = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
// If the correction is relevant, this term is <= stride, which is known
// to be positive in `index`. Otherwise, while 2 * stride might overflow,
// this branch won't be taken, so the risk of `poison` is fine.
- Value corrected = rewriter.create<arith::AddIOp>(
- loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
- Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
- corrected, remainder);
+ Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
+ arith::IntegerOverflowFlags::nsw);
+ Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
+ corrected, remainder);
return mod;
};
@@ -131,7 +132,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
// We know both inputs are positive, so floorDiv == div.
// This could potentially be a divui, but it's not clear if that would
// cause issues.
- Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+ Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
results.push_back(divided);
}
@@ -167,8 +168,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx = rewriter.create<arith::MulIOp>(
- loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+ Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
+ arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -184,8 +185,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
Value result = scaledValues.front().first;
for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
- arith::IntegerOverflowFlags::nsw);
+ result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 4fd0cf9b3cd25..3c00b323473d2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -88,8 +88,8 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter,
auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
SmallVector<Value> rhsOperands = originalOp->getOperands();
canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
- return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
- rhsOperands);
+ return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap,
+ rhsOperands);
}
FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
@@ -160,8 +160,8 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto current = createSubApply(rewriter, op, subExpressions[0]);
for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
- current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
- ValueRange{current, tmp});
+ current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
+ ValueRange{current, tmp});
LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 1d5a665bf6bb1..6c9adff7e9106 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -424,7 +424,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the block, because loop nests can be reordered
// during the fusion pass.
- Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
+ Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
// Build an AffineMap to remap access functions based on lower bound offsets.
SmallVector<AffineExpr, 4> remapExprs;
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 05a352f39a93c..c942c0248fefd 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -100,16 +100,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
}
// Create and place the alloc right before the 'affine.for' operation.
- Value newMemRef = bOuter.create<memref::AllocOp>(
- forOp.getLoc(), newMemRefType, allocOperands);
+ Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(),
+ newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
int64_t step = forOp.getStepAsInt();
auto modTwoMap =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
- auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
- forOp.getInductionVar());
+ auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap,
+ forOp.getInductionVar());
// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
@@ -130,7 +130,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
}
// Insert the dealloc op right after the for loop.
bOuter.setInsertionPointAfter(forOp);
- bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
+ memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef);
return true;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 1a266b72d1f8d..9537d3e75c26a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -51,10 +51,10 @@ OpFoldResult affine::materializeComputedBound(
"expected dynamic dim");
if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
- operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+ operands.push_back(tensor::DimOp::create(b, loc, value, *dim));
} else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
- operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+ operands.push_back(memref::DimOp::create(b, loc, value, *dim));
} else {
llvm_unreachable("cannot generate DimOp for unsupported shaped type");
}
@@ -76,7 +76,7 @@ OpFoldResult affine::materializeComputedBound(
operands[expr.getPosition() + boundMap.getNumDims()]);
// General case: build affine.apply op.
return static_cast<OpFoldResult>(
- b.create<affine::AffineApplyOp>(loc, boundMap, operands).getResult());
+ affine::AffineApplyOp::create(b, loc, boundMap, operands).getResult());
}
FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 7fae260767e0a..10da9070136c1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -905,8 +905,8 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map,
for (auto resultExpr : map.getResults()) {
auto singleResMap =
AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
- auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- mapOperands);
+ auto afOp = AffineApplyOp::create(state.builder, op->getLoc(), singleResMap,
+ mapOperands);
results.push_back(afOp);
}
}
@@ -961,7 +961,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
auto vecForOp = cast<AffineForOp>(parentOp);
state.builder.setInsertionPointToStart(vecForOp.getBody());
auto newConstOp =
- state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr);
+ arith::ConstantOp::create(state.builder, constOp.getLoc(), vecAttr);
// Register vector replacement for future uses in the scope.
state.registerOpVectorReplacement(constOp, newConstOp);
@@ -986,8 +986,8 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
}
}
- auto newApplyOp = state.builder.create<AffineApplyOp>(
- applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
+ auto newApplyOp = AffineApplyOp::create(
+ state.builder, applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
// Register the new affine.apply result.
state.registerValueScalarReplacement(applyOp.getResult(),
@@ -1010,7 +1010,7 @@ static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
auto vecTy = getVectorType(scalarTy, state.strategy);
auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr);
auto newConstOp =
- state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr);
+ arith::ConstantOp::create(state.builder, oldOperand.getLoc(), vecAttr);
return newConstOp;
}
@@ -1062,11 +1062,11 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
AffineMap ubMap = vecForOp.getUpperBoundMap();
Value ub;
if (ubMap.getNumResults() == 1)
- ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
- vecForOp.getUpperBoundOperands());
+ ub = AffineApplyOp::create(state.builder, loc, vecForOp.getUpperBoundMap(),
+ vecForOp.getUpperBoundOperands());
else
- ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
- vecForOp.getUpperBoundOperands());
+ ub = AffineMinOp::create(state.builder, loc, vecForOp.getUpperBoundMap(),
+ vecForOp.getUpperBoundOperands());
// Then we compute the number of (original) iterations left in the loop.
AffineExpr subExpr =
state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
@@ -1080,7 +1080,7 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
Type maskTy = VectorType::get(state.strategy->vectorSizes,
state.builder.getIntegerType(1));
Value mask =
- state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);
+ vector::CreateMaskOp::create(state.builder, loc, maskTy, itersLeft);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"
<< itersLeft << "\n"
@@ -1123,8 +1123,8 @@ static Operation *vectorizeUniform(Value uniformVal,
state.builder.setInsertionPointAfterValue(uniformScalarRepl);
auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
- auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
- vectorTy, uniformScalarRepl);
+ auto bcastOp = BroadcastOp::create(state.builder, uniformVal.getLoc(),
+ vectorTy, uniformScalarRepl);
state.registerValueVectorReplacement(uniformVal, bcastOp);
return bcastOp;
}
@@ -1256,8 +1256,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = state.builder.create<vector::TransferReadOp>(
- loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
+ auto transfer = vector::TransferReadOp::create(
+ state.builder, loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
/*padding=*/std::nullopt, permutationMap);
// Register replacement for future uses in the scope.
@@ -1303,9 +1303,9 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = state.builder.create<vector::TransferWriteOp>(
- storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
- permutationMap);
+ auto transfer = vector::TransferWriteOp::create(
+ state.builder, storeOp.getLoc(), vectorValue, storeOp.getMemRef(),
+ indices, permutationMap);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
// Register replacement for future uses in the scope.
@@ -1387,10 +1387,10 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
}
}
- auto vecForOp = state.builder.create<AffineForOp>(
- forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
- forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
- vecIterOperands,
+ auto vecForOp = AffineForOp::create(
+ state.builder, forOp.getLoc(), forOp.getLowerBoundOperands(),
+ forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(),
+ forOp.getUpperBoundMap(), newStep, vecIterOperands,
/*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
// Make sure we don't create a default terminator in the loop body as
// the proper terminator will be added during vectorization.
@@ -1512,8 +1512,8 @@ static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
// IterOperands are neutral element vectors.
Value neutralVal = cast<AffineForOp>(newParentOp).getInits()[i];
state.builder.setInsertionPoint(combinerOps.back());
- Value maskedReducedVal = state.builder.create<arith::SelectOp>(
- reducedVal.getLoc(), mask, reducedVal, neutralVal);
+ Value maskedReducedVal = arith::SelectOp::create(
+ state.builder, reducedVal.getLoc(), mask, reducedVal, neutralVal);
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: "
@@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
return success();
}
-
/// External utility to vectorize affine loops in 'loops' using the n-D
/// vectorization factors in 'vectorSizes'. By default, each vectorization
/// factor is applied inner-to-outer to the loops of each loop nest.
@@ -1927,4 +1926,4 @@ LogicalResult mlir::affine::vectorizeAffineLoopNest(
if (failed(verifyLoopNesting(loops)))
return failure();
return vectorizeLoopNest(loops, strategy);
-}
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 21f69ad2d4c25..2de057d1d0758 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -54,8 +54,8 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
OpBuilder b(forOp);
auto lbMap = forOp.getLowerBoundMap();
- auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
- forOp.getLowerBoundOperands());
+ auto lb = AffineApplyOp::create(b, forOp.getLoc(), lbMap,
+ forOp.getLowerBoundOperands());
// For each upper bound expr, get the range.
// Eg: affine.for %i = lb to min (ub1, ub2),
@@ -71,7 +71,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
tripCountMap.getNumSymbols(), bumpExprs[i]);
bumpValues[i] =
- b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
+ AffineApplyOp::create(b, forOp.getLoc(), bumpMap, tripCountOperands);
}
SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
@@ -134,8 +134,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
builder.setInsertionPointToStart(&func.getFunctionBody().front());
else
builder.setInsertionPoint(forOp);
- auto constOp = builder.create<arith::ConstantIndexOp>(
- forOp.getLoc(), forOp.getConstantLowerBound());
+ auto constOp = arith::ConstantIndexOp::create(
+ builder, forOp.getLoc(), forOp.getConstantLowerBound());
iv.replaceAllUsesWith(constOp);
} else {
auto lbOperands = forOp.getLowerBoundOperands();
@@ -146,7 +146,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
iv.replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp =
- builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+ AffineApplyOp::create(builder, forOp.getLoc(), lbMap, lbOperands);
iv.replaceAllUsesWith(affineApplyOp);
}
}
@@ -181,8 +181,8 @@ static AffineForOp generateShiftedLoop(
assert(ubMap.getNumInputs() == ubOperands.size());
auto loopChunk =
- b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap, ubOperands,
- ubMap, srcForOp.getStepAsInt());
+ AffineForOp::create(b, srcForOp.getLoc(), lbOperands, lbMap, ubOperands,
+ ubMap, srcForOp.getStepAsInt());
auto loopChunkIV = loopChunk.getInductionVar();
auto srcIV = srcForOp.getInductionVar();
@@ -197,8 +197,8 @@ static AffineForOp generateShiftedLoop(
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcIV.use_empty() && shift != 0) {
- auto ivRemap = bodyBuilder.create<AffineApplyOp>(
- srcForOp.getLoc(),
+ auto ivRemap = AffineApplyOp::create(
+ bodyBuilder, srcForOp.getLoc(),
bodyBuilder.getSingleDimShiftAffineMap(
-static_cast<int64_t>(srcForOp.getStepAsInt() * shift)),
loopChunkIV);
@@ -433,7 +433,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
for (unsigned i = 0; i < width; i++) {
OpBuilder b(topLoop);
// Loop bounds will be set later.
- AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
+ AffineForOp pointLoop = AffineForOp::create(b, loc, 0, 0);
pointLoop.getBody()->getOperations().splice(
pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
@@ -447,7 +447,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
for (unsigned i = width; i < 2 * width; i++) {
OpBuilder b(topLoop);
// Loop bounds will be set later.
- AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+ AffineForOp tileSpaceLoop = AffineForOp::create(b, loc, 0, 0);
tileSpaceLoop.getBody()->getOperations().splice(
tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
@@ -1048,7 +1048,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
// iv' = iv + i * step
auto d0 = b.getAffineDimExpr(0);
auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
- return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
+ return AffineApplyOp::create(b, forOp.getLoc(), bumpMap, iv);
},
/*annotateFn=*/annotateFn,
/*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
@@ -1212,7 +1212,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
auto ivUnroll =
- builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
+ AffineApplyOp::create(builder, forOp.getLoc(), bumpMap, forOpIV);
operandMaps[i - 1].map(forOpIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
@@ -1541,8 +1541,8 @@ stripmineSink(AffineForOp forOp, uint64_t factor,
for (auto t : targets) {
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator(t.getBody());
- auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
- ubOperands, ubMap, originalStep);
+ auto newForOp = AffineForOp::create(b, t.getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, originalStep);
auto begin = t.getBody()->begin();
// Skip terminator and `newForOp` which is just before the terminator.
auto nOps = t.getBody()->getOperations().size() - 2;
@@ -1616,9 +1616,9 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
// 1. Store the upper bound of the outermost loop in a variable.
Value prev;
if (!llvm::hasSingleElement(origUbMap.getResults()))
- prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
+ prev = AffineMinOp::create(builder, loc, origUbMap, ubOperands);
else
- prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
+ prev = AffineApplyOp::create(builder, loc, origUbMap, ubOperands);
upperBoundSymbols.push_back(prev);
// 2. Emit code computing the upper bound of the coalesced loop as product of
@@ -1630,16 +1630,16 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
Value upperBound;
// If upper bound map has more than one result, take their minimum.
if (!llvm::hasSingleElement(origUbMap.getResults()))
- upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
+ upperBound = AffineMinOp::create(builder, loc, origUbMap, ubOperands);
else
- upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
+ upperBound = AffineApplyOp::create(builder, loc, origUbMap, ubOperands);
upperBoundSymbols.push_back(upperBound);
SmallVector<Value, 4> operands;
operands.push_back(prev);
operands.push_back(upperBound);
// Maintain running product of loop upper bounds.
- prev = builder.create<AffineApplyOp>(
- loc,
+ prev = AffineApplyOp::create(
+ builder, loc,
AffineMap::get(/*dimCount=*/1,
/*symbolCount=*/1,
builder.getAffineDimExpr(0) *
@@ -1668,13 +1668,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
SmallVector<Value, 4> operands;
operands.push_back(previous);
operands.push_back(upperBoundSymbols[idx]);
- previous = builder.create<AffineApplyOp>(
- loc,
- AffineMap::get(
- /*dimCount=*/1, /*symbolCount=*/1,
- builder.getAffineDimExpr(0).floorDiv(
- builder.getAffineSymbolExpr(0))),
- operands);
+ previous = AffineApplyOp::create(builder, loc,
+ AffineMap::get(
+ /*dimCount=*/1, /*symbolCount=*/1,
+ builder.getAffineDimExpr(0).floorDiv(
+ builder.getAffineSymbolExpr(0))),
+ operands);
}
// Modified value of the induction variables of the nested loops after
// coalescing.
@@ -1685,8 +1684,8 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
SmallVector<Value, 4> applyOperands;
applyOperands.push_back(previous);
applyOperands.push_back(upperBoundSymbols[idx - 1]);
- inductionVariable = builder.create<AffineApplyOp>(
- loc,
+ inductionVariable = AffineApplyOp::create(
+ builder, loc,
AffineMap::get(
/*dimCount=*/1, /*symbolCount=*/1,
builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)),
@@ -1723,21 +1722,21 @@ void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp,
Value linearIndex = processorId.front();
for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
- auto mulApplyOp = b.create<AffineApplyOp>(
- loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
- linearIndex = b.create<AffineApplyOp>(
- loc, addMap, ValueRange{mulApplyOp, processorId[i]});
+ auto mulApplyOp = AffineApplyOp::create(
+ b, loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
+ linearIndex = AffineApplyOp::create(b, loc, addMap,
+ ValueRange{mulApplyOp, processorId[i]});
}
- auto mulApplyOp = b.create<AffineApplyOp>(
- loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
- Value lb = b.create<AffineApplyOp>(
- loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
+ auto mulApplyOp = AffineApplyOp::create(
+ b, loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
+ Value lb = AffineApplyOp::create(
+ b, loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
forOp.setLowerBound(lb);
Value step = forOp.getStep();
for (auto numProcs : numProcessors)
- step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
+ step = AffineApplyOp::create(b, loc, mulMap, ValueRange{numProcs, step});
forOp.setStep(step);
}
@@ -1874,7 +1873,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
auto fastBufOffsetMap =
AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
- auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
+ auto offset = AffineApplyOp::create(b, loc, fastBufOffsetMap, lbOperands);
// Construct the subscript for the fast memref being copied into/from:
// x - offset_x.
@@ -1901,16 +1900,16 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
if (!isCopyOut) {
// Copy in.
- auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
- b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
- fastBufMapOperands);
+ auto load = AffineLoadOp::create(b, loc, memref, memIndices);
+ AffineStoreOp::create(b, loc, load, fastMemRef, fastBufMap,
+ fastBufMapOperands);
return copyNestRoot;
}
// Copy out.
auto load =
- b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
- b.create<AffineStoreOp>(loc, load, memref, memIndices);
+ AffineLoadOp::create(b, loc, fastMemRef, fastBufMap, fastBufMapOperands);
+ AffineStoreOp::create(b, loc, load, memref, memIndices);
return copyNestRoot;
}
@@ -1945,7 +1944,7 @@ static LogicalResult generateCopy(
auto f = begin->getParentOfType<FunctionOpInterface>();
OpBuilder topBuilder(f.getFunctionBody());
- Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+ Value zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
*sizeInBytes = 0;
@@ -2056,7 +2055,7 @@ static LogicalResult generateCopy(
memIndices.push_back(zeroIndex);
} else {
memIndices.push_back(
- top.create<arith::ConstantIndexOp>(loc, indexVal).getResult());
+ arith::ConstantIndexOp::create(top, loc, indexVal).getResult());
}
} else {
// The coordinate for the start location is just the lower bound along the
@@ -2070,7 +2069,8 @@ static LogicalResult generateCopy(
lbs[d] = lbs[d].replaceDimsAndSymbols(
/*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(),
/*numResultSyms=*/0);
- memIndices.push_back(b.create<AffineApplyOp>(loc, lbs[d], regionSymbols));
+ memIndices.push_back(
+ AffineApplyOp::create(b, loc, lbs[d], regionSymbols));
}
// The fast buffer is copied into at location zero; addressing is relative.
bufIndices.push_back(zeroIndex);
@@ -2094,7 +2094,7 @@ static LogicalResult generateCopy(
// Create the fast memory space buffer just before the 'affine.for'
// operation.
fastMemRef =
- prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult();
+ memref::AllocOp::create(prologue, loc, fastMemRefType).getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
// fastMemRefType is a constant shaped memref.
@@ -2111,7 +2111,7 @@ static LogicalResult generateCopy(
fastMemRef = fastBufferMap[memref];
}
- auto numElementsSSA = top.create<arith::ConstantIndexOp>(loc, *numElements);
+ auto numElementsSSA = arith::ConstantIndexOp::create(top, loc, *numElements);
Value dmaStride;
Value numEltPerDmaStride;
@@ -2128,9 +2128,9 @@ static LogicalResult generateCopy(
if (!dmaStrideInfos.empty()) {
dmaStride =
- top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
- numEltPerDmaStride = top.create<arith::ConstantIndexOp>(
- loc, dmaStrideInfos[0].numEltPerStride);
+ arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].stride);
+ numEltPerDmaStride = arith::ConstantIndexOp::create(
+ top, loc, dmaStrideInfos[0].numEltPerStride);
}
}
@@ -2160,21 +2160,21 @@ static LogicalResult generateCopy(
// Create a tag (single element 1-d memref) for the DMA.
auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
copyOptions.tagMemorySpace);
- auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType);
+ auto tagMemRef = memref::AllocOp::create(prologue, loc, tagMemRefType);
SmallVector<Value, 4> tagIndices({zeroIndex});
auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
if (!region.isWrite()) {
// DMA non-blocking read from original buffer to fast buffer.
- b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
- fastMemRef, bufAffineMap, bufIndices,
- tagMemRef, tagAffineMap, tagIndices,
- numElementsSSA, dmaStride, numEltPerDmaStride);
+ AffineDmaStartOp::create(b, loc, memref, memAffineMap, memIndices,
+ fastMemRef, bufAffineMap, bufIndices, tagMemRef,
+ tagAffineMap, tagIndices, numElementsSSA,
+ dmaStride, numEltPerDmaStride);
} else {
// DMA non-blocking write from fast buffer to the original memref.
- auto op = b.create<AffineDmaStartOp>(
- loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
+ auto op = AffineDmaStartOp::create(
+ b, loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
dmaStride, numEltPerDmaStride);
// Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
@@ -2184,11 +2184,11 @@ static LogicalResult generateCopy(
}
// Matching DMA wait to block on completion; tag always has a 0 index.
- b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
- numElementsSSA);
+ AffineDmaWaitOp::create(b, loc, tagMemRef, tagAffineMap, zeroIndex,
+ numElementsSSA);
// Generate dealloc for the tag.
- auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef);
+ auto tagDeallocOp = memref::DeallocOp::create(epilogue, loc, tagMemRef);
if (*nEnd == end && isCopyOutAtEndOfBlock)
// Since new ops are being appended (for outgoing DMAs), adjust the end to
// mark end of range of the original.
@@ -2197,7 +2197,7 @@ static LogicalResult generateCopy(
// Generate dealloc for the buffer.
if (!existingBuf) {
- auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef);
+ auto bufDeallocOp = memref::DeallocOp::create(epilogue, loc, fastMemRef);
// When generating pointwise copies, `nEnd' has to be set to deallocOp on
// the fast buffer (since it marks the new end insertion point).
if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
@@ -2567,8 +2567,8 @@ AffineForOp mlir::affine::createCanonicalizedAffineForOp(
canonicalizeMapAndOperands(&ubMap, &upperOperands);
ubMap = removeDuplicateExprs(ubMap);
- return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
- step);
+ return AffineForOp::create(b, loc, lowerOperands, lbMap, upperOperands, ubMap,
+ step);
}
/// Creates an AffineIfOp that encodes the conditional to choose between
@@ -2651,8 +2651,8 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
SmallVector<Value, 4> setOperands;
cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands);
canonicalizeSetAndOperands(&ifCondSet, &setOperands);
- return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
- /*withElseRegion=*/true);
+ return AffineIfOp::create(b, loops[0].getLoc(), ifCondSet, setOperands,
+ /*withElseRegion=*/true);
}
/// Create the full tile loop nest (along with its body).
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7bb158eb6dfc0..845be20d15b69 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -56,7 +56,7 @@ class AffineApplyExpander
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
- auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
+ auto op = OpTy::create(builder, loc, lhs, rhs, overflowFlags);
return op.getResult();
}
@@ -90,14 +90,14 @@ class AffineApplyExpander
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value isRemainderNegative = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zeroCst);
+ Value remainder = arith::RemSIOp::create(builder, loc, lhs, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value isRemainderNegative = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::slt, remainder, zeroCst);
Value correctedRemainder =
- builder.create<arith::AddIOp>(loc, remainder, rhs);
- Value result = builder.create<arith::SelectOp>(
- loc, isRemainderNegative, correctedRemainder, remainder);
+ arith::AddIOp::create(builder, loc, remainder, rhs);
+ Value result = arith::SelectOp::create(builder, loc, isRemainderNegative,
+ correctedRemainder, remainder);
return result;
}
@@ -129,18 +129,19 @@ class AffineApplyExpander
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
- Value negative = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, lhs, zeroCst);
- Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
- Value dividend =
- builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
- Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value noneCst = arith::ConstantIndexOp::create(builder, loc, -1);
+ Value negative = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::slt, lhs, zeroCst);
+ Value negatedDecremented =
+ arith::SubIOp::create(builder, loc, noneCst, lhs);
+ Value dividend = arith::SelectOp::create(builder, loc, negative,
+ negatedDecremented, lhs);
+ Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs);
Value correctedQuotient =
- builder.create<arith::SubIOp>(loc, noneCst, quotient);
- Value result = builder.create<arith::SelectOp>(loc, negative,
- correctedQuotient, quotient);
+ arith::SubIOp::create(builder, loc, noneCst, quotient);
+ Value result = arith::SelectOp::create(builder, loc, negative,
+ correctedQuotient, quotient);
return result;
}
@@ -168,26 +169,26 @@ class AffineApplyExpander
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
- Value nonPositive = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, lhs, zeroCst);
- Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
- Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
- Value dividend =
- builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
- Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value oneCst = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value nonPositive = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sle, lhs, zeroCst);
+ Value negated = arith::SubIOp::create(builder, loc, zeroCst, lhs);
+ Value decremented = arith::SubIOp::create(builder, loc, lhs, oneCst);
+ Value dividend = arith::SelectOp::create(builder, loc, nonPositive, negated,
+ decremented);
+ Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs);
Value negatedQuotient =
- builder.create<arith::SubIOp>(loc, zeroCst, quotient);
+ arith::SubIOp::create(builder, loc, zeroCst, quotient);
Value incrementedQuotient =
- builder.create<arith::AddIOp>(loc, quotient, oneCst);
- Value result = builder.create<arith::SelectOp>(
- loc, nonPositive, negatedQuotient, incrementedQuotient);
+ arith::AddIOp::create(builder, loc, quotient, oneCst);
+ Value result = arith::SelectOp::create(
+ builder, loc, nonPositive, negatedQuotient, incrementedQuotient);
return result;
}
Value visitConstantExpr(AffineConstantExpr expr) {
- auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
+ auto op = arith::ConstantIndexOp::create(builder, loc, expr.getValue());
return op.getResult();
}
@@ -297,9 +298,9 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
// block.
IRMapping operandMap;
OpBuilder b(hoistOverOp);
- auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
- ifOp.getOperands(),
- /*elseBlock=*/true);
+ auto hoistedIfOp = AffineIfOp::create(b, ifOp.getLoc(), ifOp.getIntegerSet(),
+ ifOp.getOperands(),
+ /*elseBlock=*/true);
// Create a clone of hoistOverOp to use for the else branch of the hoisted
// conditional. The else block may get optimized away if empty.
@@ -368,8 +369,8 @@ mlir::affine::affineParallelize(AffineForOp forOp,
parallelReductions, [](const LoopReduction &red) { return red.value; }));
auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
parallelReductions, [](const LoopReduction &red) { return red.kind; }));
- AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
- loc, ValueRange(reducedValues).getTypes(), reductionKinds,
+ AffineParallelOp newPloop = AffineParallelOp::create(
+ outsideBuilder, loc, ValueRange(reducedValues).getTypes(), reductionKinds,
llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
llvm::ArrayRef(upperBoundMap), upperBoundOperands,
llvm::ArrayRef(forOp.getStepAsInt()));
@@ -540,7 +541,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
SmallVector<Value, 8> applyOperands{dimOperands};
applyOperands.push_back(iv);
applyOperands.append(symbolOperands.begin(), symbolOperands.end());
- auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
+ auto apply =
+ AffineApplyOp::create(builder, op.getLoc(), map, applyOperands);
iv.replaceAllUsesExcept(apply, apply);
}
@@ -621,8 +623,9 @@ LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
AffineValueMap newIvToOldIvMap;
AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
(void)newIvToOldIvMap.canonicalize();
- auto newIV = opBuilder.create<AffineApplyOp>(
- loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
+ auto newIV =
+ AffineApplyOp::create(opBuilder, loc, newIvToOldIvMap.getAffineMap(),
+ newIvToOldIvMap.getOperands());
op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
return success();
}
@@ -1186,8 +1189,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = AffineMap::get(oldMap.getNumDims(),
oldMap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- oldMapOperands);
+ auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap,
+ oldMapOperands);
oldMemRefOperands.push_back(afOp);
affineApplyOps.push_back(afOp);
}
@@ -1213,8 +1216,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
for (auto resultExpr : indexRemap.getResults()) {
auto singleResMap = AffineMap::get(
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- remapOperands);
+ auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap,
+ remapOperands);
remapOutputs.push_back(afOp);
affineApplyOps.push_back(afOp);
}
@@ -1263,8 +1266,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// AffineMapAccessInterface, we need to apply the values of `newMapOperands`
// to the `newMap` to get the correct indices.
for (unsigned i = 0; i < newMemRefRank; i++) {
- state.operands.push_back(builder.create<AffineApplyOp>(
- op->getLoc(),
+ state.operands.push_back(AffineApplyOp::create(
+ builder, op->getLoc(),
AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
newMap.getResult(i)),
newMapOperands));
@@ -1449,8 +1452,8 @@ void mlir::affine::createAffineComputationSlice(
for (auto resultExpr : composedMap.getResults()) {
auto singleResMap = AffineMap::get(composedMap.getNumDims(),
composedMap.getNumSymbols(), resultExpr);
- sliceOps->push_back(builder.create<AffineApplyOp>(
- opInst->getLoc(), singleResMap, composedOpOperands));
+ sliceOps->push_back(AffineApplyOp::create(
+ builder, opInst->getLoc(), singleResMap, composedOpOperands));
}
// Construct the new operands that include the results from the composed
@@ -1680,7 +1683,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
// Create ConstantOp for static dimension.
auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
inAffineApply.emplace_back(
- b.create<arith::ConstantOp>(allocOp.getLoc(), constantAttr));
+ arith::ConstantOp::create(b, allocOp.getLoc(), constantAttr));
}
}
@@ -1704,7 +1707,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
AffineMap newMap =
AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
Value affineApp =
- b.create<AffineApplyOp>(allocOp.getLoc(), newMap, inAffineApply);
+ AffineApplyOp::create(b, allocOp.getLoc(), newMap, inAffineApply);
newDynamicSizes.emplace_back(affineApp);
}
newDimIdx++;
@@ -1739,12 +1742,11 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
newDynamicSizes);
// Add the new dynamic sizes to the new AllocOp.
- newAlloc =
- b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType, newDynamicSizes,
- allocOp.getAlignmentAttr());
+ newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType,
+ newDynamicSizes, allocOp.getAlignmentAttr());
} else {
- newAlloc = b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType,
- allocOp.getAlignmentAttr());
+ newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType,
+ allocOp.getAlignmentAttr());
}
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1802,10 +1804,10 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i))
mapOperands[i] =
- b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
- b.create<arith::ConstantIndexOp>(loc, 1));
+ arith::SubIOp::create(b, loc, oldSizes[0].getType(), oldSizes[idx++],
+ arith::ConstantIndexOp::create(b, loc, 1));
else
- mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+ mapOperands[i] = arith::ConstantIndexOp::create(b, loc, oldShape[i] - 1);
}
for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
mapOperands[memrefType.getRank() + i] = oldStrides[i];
@@ -1815,20 +1817,20 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
for (unsigned i = 0; i < newRank; i++) {
if (!newMemRefType.isDynamicDim(i))
continue;
- newSizes.push_back(b.create<AffineApplyOp>(
- loc,
+ newSizes.push_back(AffineApplyOp::create(
+ b, loc,
AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
oldLayoutMap.getResult(i)),
mapOperands));
}
for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
newSizes[i] =
- b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
- b.create<arith::ConstantIndexOp>(loc, 1));
+ arith::AddIOp::create(b, loc, newSizes[i].getType(), newSizes[i],
+ arith::ConstantIndexOp::create(b, loc, 1));
}
// Create the new reinterpret_cast op.
- auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
- loc, newMemRefType, reinterpretCastOp.getSource(),
+ auto newReinterpretCast = memref::ReinterpretCastOp::create(
+ b, loc, newMemRefType, reinterpretCastOp.getSource(),
/*offsets=*/ValueRange(), newSizes,
/*strides=*/ValueRange(),
/*static_offsets=*/newStaticOffsets,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index ebcb951cf3518..e7cbee6b06c45 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -64,7 +64,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return ConstantOp::materialize(builder, value, type, loc);
}
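
For reference, the shape of the change applied mechanically throughout this
patch, as a minimal standalone sketch (the helper buildSum is illustrative,
not code from the patch):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// The builder moves from being the callee to being the first argument
// of a static `create` on the op class; both forms build arith.addi.
Value buildSum(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
  // Old form, being phased out:
  //   Value sum = builder.create<arith::AddIOp>(loc, lhs, rhs);
  // New form used throughout this patch:
  return arith::AddIOp::create(builder, loc, lhs, rhs);
}
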
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index f2e7732e8ea4a..9199dccdcaff3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -67,8 +67,8 @@ struct SelectOpInterface
return state.getMemrefWithUniqueOwnership(builder, value,
value.getParentBlock());
- Value ownership = builder.create<arith::SelectOp>(
- op->getLoc(), selectOp.getCondition(),
+ Value ownership = arith::SelectOp::create(
+ builder, op->getLoc(), selectOp.getCondition(),
state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
return {selectOp.getResult(), ownership};
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index afee162053bea..b073a31850678 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -170,10 +170,10 @@ struct SelectOpInterface
return failure();
if (trueBuffer.getType() != *targetType)
trueBuffer =
- rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer);
+ memref::CastOp::create(rewriter, loc, *targetType, trueBuffer);
if (falseBuffer.getType() != *targetType)
falseBuffer =
- rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer);
+ memref::CastOp::create(rewriter, loc, *targetType, falseBuffer);
}
replaceOpWithNewBufferizedOp<arith::SelectOp>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 55b757c136127..7626d356a37f2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -75,7 +75,7 @@ LogicalResult EmulateFloatPattern::matchAndRewrite(
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
if (oldType != newType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
truncFOp.setFastmath(arith::FastMathFlags::contract);
res = truncFOp.getResult();
}
@@ -98,7 +98,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = arith::ExtFOp::create(b, loc, target, input);
extFOp.setFastmath(arith::FastMathFlags::contract);
return extFOp;
});
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index d5d1559c658ff..efe6ad2579055 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -72,7 +72,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
// Scalarize the result in case of 1D vectors.
if (shape.size() == 1)
- return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
+ return vector::ExtractOp::create(rewriter, loc, input, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
offsets.back() = lastOffset;
@@ -80,8 +80,8 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
sizes.back() = 1;
SmallVector<int64_t> strides(shape.size(), 1);
- return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
- sizes, strides);
+ return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets,
+ sizes, strides);
}
/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
@@ -107,7 +107,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
assert(shape.back() == 1 && "Expected the last vector dim to be x1");
auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
- return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
+ return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input);
}
/// Performs a vector shape cast to append an x1 dimension. If the
@@ -122,7 +122,7 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
auto newShape = llvm::to_vector(vecTy.getShape());
newShape.push_back(1);
auto newTy = VectorType::get(newShape, vecTy.getElementType());
- return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
+ return vector::ShapeCastOp::create(rewriter, loc, newTy, input);
}
/// Inserts the `source` vector slice into the `dest` vector at offset
@@ -136,13 +136,13 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
// Handle scalar source.
if (isa<IntegerType>(source.getType()))
- return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
+ return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
offsets.back() = lastOffset;
SmallVector<int64_t> strides(shape.size(), 1);
- return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
- offsets, strides);
+ return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest,
+ offsets, strides);
}
/// Constructs a new vector of type `resultType` by creating a series of
@@ -254,12 +254,12 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
auto lowSum =
- rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
+ arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
Value overflowVal =
- rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
+ arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
- Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
- Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
+ Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
+ Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
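
A minimal scalar C++ model (illustrative only, not the MLIR pattern itself)
of the arithmetic ConvertAddI emulates, here for a 64-bit add built from
32-bit halves:

#include <cstdint>

uint64_t wideAdd(uint32_t lhsLo, uint32_t lhsHi, uint32_t rhsLo,
                 uint32_t rhsHi) {
  uint32_t lowSum = lhsLo + rhsLo;            // addui_extended: sum
  uint32_t overflow = lowSum < lhsLo ? 1 : 0; // addui_extended: overflow bit
  uint32_t high = overflow + lhsHi + rhsHi;   // the two arith.addi ops
  return (uint64_t(high) << 32) | lowSum;
}
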
@@ -293,8 +293,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
auto [rhsElem0, rhsElem1] =
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
- Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
- Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
+ Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
rewriter.replaceOp(op, resultVec);
@@ -346,26 +346,26 @@ struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
Value lowCmp =
- rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
+ arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
Value highCmp =
- rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
+ arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
Value cmpResult{};
switch (highPred) {
case arith::CmpIPredicate::eq: {
- cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
+ cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
break;
}
case arith::CmpIPredicate::ne: {
- cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
+ cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
break;
}
default: {
// Handle inequality checks.
- Value highEq = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
+ Value highEq = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
cmpResult =
- rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
+ arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
break;
}
}
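
A scalar model (illustrative, not the pattern) of the inequality case for
ult: the high halves decide unless they are equal, in which case the low
halves do:

#include <cstdint>

bool wideULT(uint32_t lhsLo, uint32_t lhsHi, uint32_t rhsLo, uint32_t rhsHi) {
  bool lowCmp = lhsLo < rhsLo;      // cmpi with the low predicate
  bool highCmp = lhsHi < rhsHi;     // cmpi with the high predicate
  bool highEq = lhsHi == rhsHi;     // cmpi eq on the high halves
  return highEq ? lowCmp : highCmp; // the arith.select
}
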
@@ -401,14 +401,14 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
// Multiplying two i2N integers produces (at most) an i4N result, but
// because the calculation of the top i2N is not necessary, we omit it.
auto mulLowLow =
- rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
- Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
- Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
+ arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
+ Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
Value resLow = mulLowLow.getLow();
Value resHi =
- rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
- resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
+ arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
+ resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resLow, resHi});
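
A scalar model (illustrative only) of ConvertMulI: the hi*hi partial product
only feeds bits above the result width, so it is never computed:

#include <cstdint>

uint64_t wideMulLow64(uint32_t lhsLo, uint32_t lhsHi, uint32_t rhsLo,
                      uint32_t rhsHi) {
  uint64_t mulLowLow = uint64_t(lhsLo) * rhsLo; // mului_extended: low and high
  uint32_t mulLowHi = lhsLo * rhsHi;            // arith.muli
  uint32_t mulHiLow = lhsHi * rhsLo;            // arith.muli
  uint32_t resHi = uint32_t(mulLowLow >> 32) + mulLowHi + mulHiLow;
  return (uint64_t(resHi) << 32) | uint32_t(mulLowLow);
}
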
@@ -443,10 +443,10 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
loc, newResultComponentTy, newOperand);
Value operandZeroCst =
createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
- Value signBit = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
+ Value signBit = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
Value signValue =
- rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
+ arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {extended, signValue});
@@ -508,7 +508,7 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
// Rewrite Max*I/Min*I as compare and select over original operands. Let
// the CmpI and Select emulation patterns handle the final legalization.
Value cmp =
- rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
+ arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
op.getRhs());
return success();
@@ -587,7 +587,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
// Sign or zero-extend the result. Let the matching conversion pattern
// legalize the extension op.
Value underlyingVal =
- rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
+ CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
return success();
}
@@ -616,9 +616,9 @@ struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
Value resElem0 =
- rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
+ arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
Value resElem1 =
- rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
+ arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
rewriter.replaceOp(op, resultVec);
@@ -680,33 +680,33 @@ struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
Value elemBitWidth =
createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
- Value illegalElemShift = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+ Value illegalElemShift = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
Value shiftedElem0 =
- rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
- Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem0);
+ arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem0);
- Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value cappedShiftAmount = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
Value rightShiftAmount =
- rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
Value shiftedRight =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
Value overshotShiftAmount =
- rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
Value shiftedLeft =
- rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
+ arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
Value shiftedElem1 =
- rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
- Value resElem1High = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, zeroCst, shiftedElem1);
- Value resElem1Low = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, shiftedLeft, shiftedRight);
+ arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
+ Value resElem1High = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
+ Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ shiftedLeft, shiftedRight);
Value resElem1 =
- rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
+ arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
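
A scalar model (illustrative; shift amounts assumed in [0, 64)) of the
decomposition ConvertShLI performs, written with explicit branches instead
of the pattern's selects:

#include <cstdint>

uint64_t wideShl(uint32_t lo, uint32_t hi, unsigned amount) {
  bool bigShift = amount >= 32; // illegalElemShift in the pattern
  uint32_t resLo = bigShift ? 0 : lo << amount;
  uint32_t spill; // low-half bits that land in the high half
  if (amount == 0)
    spill = 0;
  else if (bigShift)
    spill = lo << (amount - 32);
  else
    spill = lo >> (32 - amount);
  uint32_t resHi = (bigShift ? 0 : hi << amount) | spill;
  return (uint64_t(resHi) << 32) | resLo;
}
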
@@ -769,33 +769,33 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
Value elemBitWidth =
createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
- Value illegalElemShift = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+ Value illegalElemShift = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
Value shiftedElem0 =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
- Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem0);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem0);
Value shiftedElem1 =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
- Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem1);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
+ Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem1);
- Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value cappedShiftAmount = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
Value leftShiftAmount =
- rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
Value shiftedLeft =
- rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
+ arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
Value overshotShiftAmount =
- rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
Value shiftedRight =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
- Value resElem0High = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, shiftedRight, shiftedLeft);
+ Value resElem0High = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
Value resElem0 =
- rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
+ arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
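
The mirror image of the shift-left sketch above, again illustrative and
assuming amounts in [0, 64):

#include <cstdint>

uint64_t wideShrUI(uint32_t lo, uint32_t hi, unsigned amount) {
  bool bigShift = amount >= 32;
  uint32_t resHi = bigShift ? 0 : hi >> amount;
  uint32_t spill; // high-half bits that land in the low half
  if (amount == 0)
    spill = 0;
  else if (bigShift)
    spill = hi >> (amount - 32);
  else
    spill = hi << (32 - amount);
  uint32_t resLo = (bigShift ? 0 : lo >> amount) | spill;
  return (uint64_t(resHi) << 32) | resLo;
}
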
@@ -832,33 +832,33 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
// Perform as many ops over the narrow integer type as possible and let the
// other emulation patterns convert the rest.
Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
- Value signBit = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
+ Value signBit = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
signBit = dropTrailingX1Dim(rewriter, loc, signBit);
// Create a bit pattern of either all ones or all zeros. Then shift it left
// to calculate the sign extension bits created by shifting the original
// sign bit right.
- Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
+ Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
Value maxShift =
createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
Value numNonSignExtBits =
- rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
+ arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
numNonSignExtBits =
- rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
+ arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
Value signBits =
- rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
+ arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
// Use original arguments to create the right shift.
Value shrui =
- rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
- Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
+ arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
+ Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
// Handle shifting by zero. This is necessary when the `signBits` shift is
// invalid.
- Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- rhsElem0, elemZero);
+ Value isNoop = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
shrsi);
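
A scalar model (illustrative, amounts assumed in [0, 64)) of ConvertShRSI:
an arithmetic shift is the logical shift ORed with sign-extension bits,
plus a final select for the shift-by-zero case:

#include <cstdint>

uint64_t wideShrSI(uint64_t lhs, unsigned amount) {
  bool signBit = int64_t(lhs) < 0;               // sign of the high half
  uint64_t allSign = signBit ? ~uint64_t(0) : 0; // extsi of the sign bit
  uint64_t shrui = lhs >> amount;                // shrui on original args
  if (amount == 0)                               // the isNoop select
    return lhs;
  uint64_t signBits = allSign << (64 - amount);  // numNonSignExtBits shift
  return shrui | signBits;                       // arith.ori
}
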
@@ -892,14 +892,14 @@ struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
// Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
// CARRY is 1 or 0.
- Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+ Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
// We have a carry if lhsElem0 < rhsElem0.
- Value carry0 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
- Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
+ Value carry0 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+ Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
- Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
- Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
+ Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
+ Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
rewriter.replaceOp(op, resultVec);
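
A scalar model (illustrative only) of ConvertSubI's borrow handling:

#include <cstdint>

uint64_t wideSub(uint32_t lhsLo, uint32_t lhsHi, uint32_t rhsLo,
                 uint32_t rhsHi) {
  uint32_t low = lhsLo - rhsLo;
  uint32_t carry = lhsLo < rhsLo ? 1 : 0; // cmpi ult + extui
  uint32_t high = lhsHi - carry - rhsHi;  // the two arith.subi ops
  return (uint64_t(high) << 32) | low;
}
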
@@ -933,13 +933,13 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
// result or not based on that sign bit. We implement negation by
// subtracting from zero. Note that this relies on the other conversion
// patterns to legalize created ops and narrow the bit widths.
- Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- in, zeroCst);
- Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
- Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
+ Value isNeg = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::slt, in, zeroCst);
+ Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
+ Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
- Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
- Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
+ Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
+ Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
absResult);
return success();
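
A scalar model (illustrative only) of ConvertSIToFP's structure: take the
absolute value, convert unsigned, then negate the float result for negative
inputs:

#include <cstdint>

double sitofpViaUitofp(int64_t in) {
  bool isNeg = in < 0;
  uint64_t abs = isNeg ? uint64_t(0) - uint64_t(in) : uint64_t(in);
  double absResult = double(abs); // the deferred arith.uitofp
  return isNeg ? -absResult : absResult;
}
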
@@ -985,13 +985,13 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
//
// Note 2: We do not strictly need the `hi == 0` case, but it makes
// constant folding easier.
- Value hiEqZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
+ Value hiEqZero = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
Type resultTy = op.getType();
Type resultElemTy = getElementTypeOrSelf(resultTy);
- Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
- Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
+ Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
+ Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
int64_t pow2Int = int64_t(1) << newBitWidth;
TypedAttr pow2Attr =
@@ -999,10 +999,11 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
if (auto vecTy = dyn_cast<VectorType>(resultTy))
pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
- Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
+ Value pow2Val =
+ arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
- Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
- Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
+ Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
+ Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
return success();
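
A scalar model (illustrative only) of ConvertUIToFP: each half converts
exactly, and the result is lo + hi * 2^32, with the hi == 0 shortcut noted
above:

#include <cstdint>

double uitofpFromHalves(uint32_t lowInt, uint32_t hiInt) {
  double lowFp = double(lowInt);
  double hiFp = double(hiInt);
  double pow2 = 4294967296.0; // 2^32, the pow2Val constant
  return hiInt == 0 ? lowFp : lowFp + hiFp * pow2;
}
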
@@ -1037,22 +1038,22 @@ struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
// result is UB.
TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
- Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0);
// Get the absolute value. One could have used math.absf here, but that
// introduces an extra dependency.
- Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
- inFp, zeroCst);
- Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp);
+ Value isNeg = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
+ Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
- Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
+ Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
// Defer the absolute value to fptoui.
- Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal);
+ Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
// Negate the value if < 0.
- Value neg = rewriter.create<arith::SubIOp>(loc, zeroCstInt, res);
+ Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
return success();
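
A scalar model (illustrative; out-of-range inputs are UB, matching the
comment above) of ConvertFPToSI's abs-then-negate structure:

#include <cstdint>

int64_t fptosiViaFptoui(double inFp) {
  bool isNeg = inFp < 0.0;              // cmpf olt against zero
  double absVal = isNeg ? -inFp : inFp; // negf + select
  uint64_t res = uint64_t(absVal);      // the deferred arith.fptoui
  uint64_t neg = uint64_t(0) - res;     // subi from zero
  return int64_t(isNeg ? neg : res);
}
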
@@ -1109,17 +1110,17 @@ struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
if (auto vecType = dyn_cast<VectorType>(fpTy))
powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
Value powBitwidthFloatCst =
- rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
+ arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
Value fpDivPowBitwidth =
- rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
+ arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
Value resHigh =
- rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
+ arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
// Calculate fp - resHigh * 2^N by getting the remainder of the division
Value remainder =
- rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
+ arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
Value resLow =
- rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
+ arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
Value high = appendX1Dim(rewriter, loc, resHigh);
Value low = appendX1Dim(rewriter, loc, resLow);
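
A scalar model (illustrative; assumes a non-negative, in-range input) of
the quotient/remainder split ConvertFPToUI performs:

#include <cmath>
#include <cstdint>

uint64_t fptouiFromHalves(double inFp) {
  double powBitwidth = 4294967296.0; // 2^32
  uint32_t resHigh = uint32_t(inFp / powBitwidth);          // fptoui(quotient)
  uint32_t resLow = uint32_t(std::fmod(inFp, powBitwidth)); // remf + fptoui
  return (uint64_t(resHigh) << 32) | resLow;
}
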
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e842f44b3b97f..f8fa35c6fa7de 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -28,10 +28,10 @@ static Value createConst(Location loc, Type type, int value,
PatternRewriter &rewriter) {
auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return rewriter.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(rewriter, loc, attr);
}
/// Create a float constant.
@@ -39,11 +39,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return rewriter.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(rewriter, loc, attr);
}
/// Creates a shaped type using the shape from cloneFrom and the base type
/// from cloneTo.
@@ -67,11 +67,11 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
Value b = op.getRhs();
Value zero = createConst(loc, a.getType(), 0, rewriter);
Value compare =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
Value one = createConst(loc, a.getType(), 1, rewriter);
- Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
- Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
- Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+ Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
+ Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
+ Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
return success();
}
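
The identity this converter implements, as a one-line scalar sketch
(illustrative; b must be nonzero):

#include <cstdint>

uint64_t ceilDivUI(uint64_t a, uint64_t b) {
  return a == 0 ? 0 : (a - 1) / b + 1;
}
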
@@ -96,22 +96,22 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Value zero = createConst(loc, type, 0, rewriter);
Value one = createConst(loc, type, 1, rewriter);
- Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
- Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
- Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, a, product);
+ Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
+ Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
+ Value notEqualDivisor = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, a, product);
- Value aNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value bNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
+ Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ a, zero);
+ Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ b, zero);
- Value signEqual = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+ Value signEqual = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
Value cond =
- rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+ arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
- Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+ Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
quotient);
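
A scalar sketch (illustrative) of the signed variant: round the truncating
quotient up exactly when there is a remainder and the signs agree:

#include <cstdint>

int64_t ceilDivSI(int64_t a, int64_t b) {
  int64_t quotient = a / b; // divsi truncates toward zero
  bool notEqualDivisor = a != quotient * b;
  bool signEqual = (a < 0) == (b < 0);
  return notEqualDivisor && signEqual ? quotient + 1 : quotient;
}
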
@@ -135,25 +135,25 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
Value a = op.getLhs();
Value b = op.getRhs();
- Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
- Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
- Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, a, product);
+ Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
+ Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
+ Value notEqualDivisor = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, a, product);
Value zero = createConst(loc, type, 0, rewriter);
- Value aNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value bNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
+ Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ a, zero);
+ Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ b, zero);
- Value signOpposite = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, aNeg, bNeg);
+ Value signOpposite = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
Value cond =
- rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
+ arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
Value minusOne = createConst(loc, type, -1, rewriter);
Value quotientMinusOne =
- rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
+ arith::AddIOp::create(rewriter, loc, quotient, minusOne);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
quotient);
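
And the floor counterpart (illustrative): round down exactly when there is
a remainder and the signs differ:

#include <cstdint>

int64_t floorDivSI(int64_t a, int64_t b) {
  int64_t quotient = a / b; // divsi truncates toward zero
  bool notEqualDivisor = a != quotient * b;
  bool signOpposite = (a < 0) != (b < 0);
  return notEqualDivisor && signOpposite ? quotient - 1 : quotient;
}
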
@@ -171,7 +171,7 @@ struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
- Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
+ Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
return success();
}
@@ -192,12 +192,12 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
static_assert(pred == arith::CmpFPredicate::UGT ||
pred == arith::CmpFPredicate::ULT,
"pred must be either UGT or ULT");
- Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
- Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
+ Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
// Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
- rhs, rhs);
+ Value isNaN = arith::CmpFOp::create(rewriter, loc,
+ arith::CmpFPredicate::UNO, rhs, rhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
return success();
}
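
A scalar model (illustrative) of the maximumf expansion, showing how the
two selects propagate NaNs from either side:

double maximumF(double lhs, double rhs) {
  bool cmp = !(lhs <= rhs);        // cmpf ugt: unordered or lhs > rhs
  double select = cmp ? lhs : rhs; // yields lhs when lhs is NaN
  bool rhsIsNaN = rhs != rhs;      // cmpf uno(rhs, rhs)
  return rhsIsNaN ? rhs : select;
}
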
@@ -218,12 +218,12 @@ struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
static_assert(pred == arith::CmpFPredicate::UGT ||
pred == arith::CmpFPredicate::ULT,
"pred must be either UGT or ULT");
- Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
- Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
+ Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
// Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
- lhs, lhs);
+ Value isNaN = arith::CmpFOp::create(rewriter, loc,
+ arith::CmpFPredicate::UNO, lhs, lhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
return success();
}
@@ -247,12 +247,12 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
+ Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
- Value shl = b.create<arith::ShLIOp>(exti, c16);
- Value result = b.create<arith::BitcastOp>(resultTy, shl);
+ Value shl = arith::ShLIOp::create(b, exti, c16);
+ Value result = arith::BitcastOp::create(b, resultTy, shl);
rewriter.replaceOp(op, result);
return success();
@@ -296,7 +296,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// exponent bits, that simple truncation is the desired outcome for
// infinities.
Value isNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
@@ -305,30 +305,30 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
// Reinterpret the input f32 value as bits.
- Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
// Read bit 16 as a value in {0,1}.
Value bit16 =
- b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+ arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
// Determine the rounding bias to add as either 0x7fff or 0x8000 depending
// on bit 16, implementing the tie-breaking "to nearest even".
- Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+ Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
// Add the rounding bias. Generally we want this to be added to the
// mantissa, but nothing prevents this from carrying into the exponent
// bits, which would feel like a bug, but this is the magic trick here:
// when that happens, the mantissa gets reset to zero and the exponent
// gets incremented by the carry... which is actually exactly what we
// want.
- Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+ Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
- Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+ Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
Value normalCaseResultI16 =
- b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+ arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
- Value result = b.create<arith::BitcastOp>(resultTy, select);
+ arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
+ Value result = arith::BitcastOp::create(b, resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
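
The rounding trick above, as a standalone scalar sketch (illustrative
names; the behavior mirrors the pattern, including the quiet-NaN constant):

#include <cstdint>
#include <cstring>

uint16_t truncF32ToBF16(float operand) {
  if (operand != operand) // the cmpf une self-comparison NaN check
    return 0x7fc0;        // quiet NaN, the c7FC0 constant
  uint32_t bitcast;
  std::memcpy(&bitcast, &operand, sizeof(bitcast));
  uint32_t bit16 = (bitcast >> 16) & 1;
  uint32_t roundingBias = bit16 + 0x7fff;   // 0x7fff or 0x8000
  uint32_t biased = bitcast + roundingBias; // carry may bump the exponent
  return uint16_t(biased >> 16);
}
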
@@ -381,7 +381,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+ Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
@@ -390,38 +390,39 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
// Set last Exponent bit and Mantissa.
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
- Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
+ Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
Value isHalf =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
- bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
- bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
- bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+ bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
+ bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
+ bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
// Set first 7 bits of Exponent.
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
Value useLargerExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
Value bits25To31 =
- b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+ arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
Value zeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
- bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
+ bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
// Set sign.
Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
Value negative =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
- Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
+ Value bit32 =
+ arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
// Add segments together.
- Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
- Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
- Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
+ Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
+ Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
+ Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
if (!isa<Float32Type>(resultETy))
- result = b.create<arith::TruncFOp>(resultTy, result);
+ result = arith::TruncFOp::create(b, resultTy, result);
rewriter.replaceOp(op, result);
return success();
@@ -447,25 +448,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
// Create constants for NaNs.
Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+ Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
+ Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
// Select for NaNs.
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+ Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
- op.getFastmathAttr());
+ result = arith::TruncFOp::create(b, resultTy, result, nullptr,
+ op.getFastmathAttr());
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
+ result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
}
rewriter.replaceOp(op, result);
return success();
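
A scalar sketch (illustrative) of the f8E8M0 -> f32 extension: the 8-bit
payload is pure exponent, so it shifts straight into the f32 exponent
field, with 0xff mapped to NaN:

#include <cstdint>
#include <cstring>

float extF8E8M0ToF32(uint8_t bits) {
  uint32_t f32Bits = uint32_t(bits) << 23; // place past the 23 mantissa bits
  if (bits == 0xff)                        // the f8 NaN encoding
    f32Bits = 0xffffffff;                  // an f32 NaN pattern (cF32NaN)
  float result;
  std::memcpy(&result, &f32Bits, sizeof(result));
  return result;
}
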
@@ -520,7 +521,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
if (!isa<Float4E2M1FNType>(resultETy))
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
if (!isa<Float32Type>(operandETy))
- operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ operand = arith::ExtFOp::create(b, f32Ty, operand);
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
@@ -532,65 +533,65 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Step 0: Clamp to bounds.
Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
- Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
- operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+ Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
+ operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
+ Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
// Step 1: Set sign bit.
Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 31
- Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
- Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
- Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
+ Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
+ Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
+ Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
// Step 2: Convert exponent by adjusting bias.
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
Value cF4MantissaWidth = c0x1; // 1
Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp =
- b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
- Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
- f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
- f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
+ arith::SubIOp::create(b, f32SignExp, biasAdjustment);
+ Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
+ f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
+ f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
// Step 3: Set mantissa to first bit.
Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
- Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
- man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
- Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
- f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
+ Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
+ man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
+ Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
+ f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
// Step 4: Special consideration for conversion to 0.5.
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
- Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+ Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
Value isNegOneExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
- man23Bits, zeroExpBits);
- Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
+ Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
+ Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
+ man23Bits, zeroExpBits);
+ Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
Value isZeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
Value subResult =
- b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
- subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
- f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
+ arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
+ subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
+ f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
// Step 5: Round up if necessary.
Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
- Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+ Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
Value shouldRound =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
- shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
- Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
- f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
+ shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
+ Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
+ f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
- Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+ Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
rewriter.replaceOp(op, result);
return success();
}
@@ -625,16 +626,16 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
if (operandETy.getIntOrFloatBitWidth() < 32) {
- operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
+ operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
- operand = b.create<arith::TruncFOp>(
- f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
+ operand = arith::TruncFOp::create(
+ b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
}
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
- Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
+ Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
+ Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
rewriter.replaceOp(op, result);
return success();
}
@@ -653,8 +654,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
- scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
- op.getFastmathAttr());
+ scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
}
// Catch scale types like f8E5M2.
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
@@ -666,11 +667,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
// extf on scale will essentially create a floating point number
// of type resultTy that is 2^scale and will also propagate NaNs
Value scaleExt =
- b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+ arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
Value inputExt =
- b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+ arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
Value result =
- b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
+ arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
rewriter.replaceOp(op, result);
return success();
}
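
A scalar sketch (illustrative; scaleExp stands for the unbiased exponent
held by the f8E8M0 scale operand) of what the extf-then-mulf amounts to:

#include <cmath>

double scalingExtF(float input, int scaleExp) {
  double scaleExt = std::ldexp(1.0, scaleExp); // extf of e8m0 gives 2^scale
  return double(input) * scaleExt;             // the arith.mulf
}
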
@@ -695,8 +696,8 @@ struct ScalingTruncFOpConverter
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
- scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
- op.getFastmathAttr());
+ scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
@@ -708,11 +709,11 @@ struct ScalingTruncFOpConverter
// this will create a floating point number of type
// inputTy that is 2^scale and will also propagate NaNs
scaleOperand =
- b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
- Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
- op.getFastmathAttr());
- Value resultCast = b.create<arith::TruncFOp>(
- resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
+ arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
+ Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
+ op.getFastmathAttr());
+ Value resultCast = arith::TruncFOp::create(
+ b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
rewriter.replaceOp(op, resultCast);
return success();
}
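
And the truncation direction (illustrative, same scaleExp assumption):
divide by 2^scale in the input type, then truncate to the result type:

#include <cmath>

float scalingTruncF(double input, int scaleExp) {
  double scaleExt = std::ldexp(1.0, scaleExp); // extf of e8m0 gives 2^scale
  return float(input / scaleExt);              // divf, then truncf
}
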
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f2f93883eb2b7..777ff0ecaa314 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -305,18 +305,18 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
if (castKind == CastKind::Signed)
- return builder.create<arith::IndexCastOp>(loc, dstType, src);
- return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+ return arith::IndexCastOp::create(builder, loc, dstType, src);
+ return arith::IndexCastUIOp::create(builder, loc, dstType, src);
}
auto srcInt = cast<IntegerType>(srcElemType);
auto dstInt = cast<IntegerType>(dstElemType);
if (dstInt.getWidth() < srcInt.getWidth())
- return builder.create<arith::TruncIOp>(loc, dstType, src);
+ return arith::TruncIOp::create(builder, loc, dstType, src);
if (castKind == CastKind::Signed)
- return builder.create<arith::ExtSIOp>(loc, dstType, src);
- return builder.create<arith::ExtUIOp>(loc, dstType, src);
+ return arith::ExtSIOp::create(builder, loc, dstType, src);
+ return arith::ExtUIOp::create(builder, loc, dstType, src);
}
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 5fb7953f93700..4bdd1e6a54d69 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -23,8 +23,8 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
switch (e.getKind()) {
case AffineExprKind::Constant:
- return b.create<ConstantIndexOp>(loc,
- cast<AffineConstantExpr>(e).getValue());
+ return ConstantIndexOp::create(b, loc,
+ cast<AffineConstantExpr>(e).getValue());
case AffineExprKind::DimId:
return operands[cast<AffineDimExpr>(e).getPosition()];
case AffineExprKind::SymbolId:
@@ -32,28 +32,28 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
map.getNumDims()];
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return AddIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return MulIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::FloorDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return DivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::CeilDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return CeilDivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return RemSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
}
llvm_unreachable("unsupported AffineExpr kind");
@@ -89,10 +89,10 @@ FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
"expected dynamic dim");
if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
- operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+ operands.push_back(tensor::DimOp::create(b, loc, value, *dim));
} else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
- operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+ operands.push_back(memref::DimOp::create(b, loc, value, *dim));
} else {
llvm_unreachable("cannot generate DimOp for unsupported shaped type");
}
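`buildArithValue` recursively lowers each `AffineExpr` node to its arith counterpart: add to `arith.addi`, mul to `arith.muli`, floordiv/ceildiv to the signed division ops, and mod to `arith.remsi`. For example, the map `(d0) -> (d0 mod 4)` would produce roughly the following (op names unqualified as in this file; `b`, `loc`, and the index operand `d0` are hypothetical):

  Value c4 = ConstantIndexOp::create(b, loc, 4);
  Value rem = RemSIOp::create(b, loc, d0, c4);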
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index 3478adcb4a128..dd6efe6d6bc31 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -83,7 +83,7 @@ struct ConstantShardingInterface
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
- auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+ auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue);
spmdizationMap.map(op->getResult(0), newOp.getResult());
spmdizationMap.map(op, newOp.getOperation());
} else {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bdeeccbe0177a..b1fc9aa57c3ba 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -67,7 +67,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
// dynamism.
Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
Value indexGroupStaticSizesProduct =
- b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+ arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
loc, indexGroupSize, indexGroupStaticSizesProduct);
outputShapeValues.push_back(dynamicDimSize);
@@ -104,8 +104,8 @@ Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
- return b.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
+ return arith::ConstantOp::create(
+ b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
}
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
@@ -113,7 +113,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
- return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
+ return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
}
Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
@@ -124,7 +124,7 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
bool targetIsIndex = targetType.isIndex();
bool valueIsIndex = value.getType().isIndex();
if (targetIsIndex ^ valueIsIndex)
- return b.create<arith::IndexCastOp>(loc, targetType, value);
+ return arith::IndexCastOp::create(b, loc, targetType, value);
auto targetIntegerType = dyn_cast<IntegerType>(targetType);
auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
@@ -133,8 +133,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
- return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
- return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
+ return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
+ return arith::TruncIOp::create(b, loc, targetIntegerType, value);
}
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
@@ -142,21 +142,21 @@ static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
// If operand is floating point, cast directly to the int type.
if (isa<FloatType>(operand.getType())) {
if (isUnsigned)
- return b.create<arith::FPToUIOp>(toType, operand);
- return b.create<arith::FPToSIOp>(toType, operand);
+ return arith::FPToUIOp::create(b, toType, operand);
+ return arith::FPToSIOp::create(b, toType, operand);
}
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
- return b.create<arith::IndexCastOp>(toType, operand);
+ return arith::IndexCastOp::create(b, toType, operand);
if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
// Either extend or truncate.
if (toType.getWidth() > fromIntType.getWidth()) {
if (isUnsigned)
- return b.create<arith::ExtUIOp>(toType, operand);
- return b.create<arith::ExtSIOp>(toType, operand);
+ return arith::ExtUIOp::create(b, toType, operand);
+ return arith::ExtSIOp::create(b, toType, operand);
}
if (toType.getWidth() < fromIntType.getWidth())
- return b.create<arith::TruncIOp>(toType, operand);
+ return arith::TruncIOp::create(b, toType, operand);
return operand;
}
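To summarize the dispatch in `convertScalarToIntDtype`: floating-point sources lower to `arith.fptoui`/`arith.fptosi`, index sources to `arith.index_cast`, and integer sources extend or truncate based on the relative widths. A sketch for an f32 source and an unsigned i64 target, assuming a hypothetical `ImplicitLocOpBuilder b` (note the location-free create calls):

  Value asI64 = arith::FPToUIOp::create(b, b.getI64Type(), f32Val);

The float-to-float helper that follows mirrors this shape with `arith.extf`/`arith.truncf`.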
@@ -169,14 +169,14 @@ static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
// Note that it is unclear how to cast from BF16<->FP16.
if (isa<IntegerType>(operand.getType())) {
if (isUnsigned)
- return b.create<arith::UIToFPOp>(toType, operand);
- return b.create<arith::SIToFPOp>(toType, operand);
+ return arith::UIToFPOp::create(b, toType, operand);
+ return arith::SIToFPOp::create(b, toType, operand);
}
if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
if (toType.getWidth() > fromFpTy.getWidth())
- return b.create<arith::ExtFOp>(toType, operand);
+ return arith::ExtFOp::create(b, toType, operand);
if (toType.getWidth() < fromFpTy.getWidth())
- return b.create<arith::TruncFOp>(toType, operand);
+ return arith::TruncFOp::create(b, toType, operand);
return operand;
}
@@ -189,18 +189,18 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
if (isa<FloatType>(targetType.getElementType()) &&
isa<FloatType>(fromComplexType.getElementType())) {
- Value real = b.create<complex::ReOp>(operand);
- Value imag = b.create<complex::ImOp>(operand);
+ Value real = complex::ReOp::create(b, operand);
+ Value imag = complex::ImOp::create(b, operand);
Type targetETy = targetType.getElementType();
if (targetType.getElementType().getIntOrFloatBitWidth() <
fromComplexType.getElementType().getIntOrFloatBitWidth()) {
- real = b.create<arith::TruncFOp>(targetETy, real);
- imag = b.create<arith::TruncFOp>(targetETy, imag);
+ real = arith::TruncFOp::create(b, targetETy, real);
+ imag = arith::TruncFOp::create(b, targetETy, imag);
} else {
- real = b.create<arith::ExtFOp>(targetETy, real);
- imag = b.create<arith::ExtFOp>(targetETy, imag);
+ real = arith::ExtFOp::create(b, targetETy, real);
+ imag = arith::ExtFOp::create(b, targetETy, imag);
}
- return b.create<complex::CreateOp>(targetType, real, imag);
+ return complex::CreateOp::create(b, targetType, real, imag);
}
}
@@ -209,27 +209,27 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
Value from = operand;
if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
- from = b.create<arith::ExtFOp>(toFpTy, from);
+ from = arith::ExtFOp::create(b, toFpTy, from);
}
if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
- from = b.create<arith::TruncFOp>(toFpTy, from);
+ from = arith::TruncFOp::create(b, toFpTy, from);
}
- Value zero = b.create<mlir::arith::ConstantFloatOp>(
- toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
- return b.create<complex::CreateOp>(targetType, from, zero);
+ Value zero = mlir::arith::ConstantFloatOp::create(
+ b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
+ return complex::CreateOp::create(b, targetType, from, zero);
}
if (isa<IntegerType>(operand.getType())) {
FloatType toFpTy = cast<FloatType>(targetType.getElementType());
Value from = operand;
if (isUnsigned) {
- from = b.create<arith::UIToFPOp>(toFpTy, from);
+ from = arith::UIToFPOp::create(b, toFpTy, from);
} else {
- from = b.create<arith::SIToFPOp>(toFpTy, from);
+ from = arith::SIToFPOp::create(b, toFpTy, from);
}
- Value zero = b.create<mlir::arith::ConstantFloatOp>(
- toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
- return b.create<complex::CreateOp>(targetType, from, zero);
+ Value zero = mlir::arith::ConstantFloatOp::create(
+ b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
+ return complex::CreateOp::create(b, targetType, from, zero);
}
return {};
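In the real-to-complex path above, the scalar is first widened, truncated, or converted from integer to the target element type and then paired with a zero imaginary part. A condensed sketch for an i32 source and a `complex<f32>` target, with hypothetical `b` (ImplicitLocOpBuilder), `cplxTy`, and `f32Ty`:

  Value re = arith::SIToFPOp::create(b, f32Ty, i32Val);
  Value im = mlir::arith::ConstantFloatOp::create(
      b, f32Ty, mlir::APFloat(f32Ty.getFloatSemantics(), 0));
  Value c = complex::CreateOp::create(b, cplxTy, re, im);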
@@ -277,7 +277,7 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
attr = SplatElementsAttr::get(vecTy, value);
}
- return builder.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(builder, loc, attr);
}
Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
@@ -309,35 +309,35 @@ Type mlir::getType(OpFoldResult ofr) {
}
Value ArithBuilder::_and(Value lhs, Value rhs) {
- return b.create<arith::AndIOp>(loc, lhs, rhs);
+ return arith::AndIOp::create(b, loc, lhs, rhs);
}
Value ArithBuilder::add(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::AddFOp>(loc, lhs, rhs);
- return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
+ return arith::AddFOp::create(b, loc, lhs, rhs);
+ return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::SubFOp>(loc, lhs, rhs);
- return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
+ return arith::SubFOp::create(b, loc, lhs, rhs);
+ return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::MulFOp>(loc, lhs, rhs);
- return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
+ return arith::MulFOp::create(b, loc, lhs, rhs);
+ return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
- return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
+ return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
+ return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
}
Value ArithBuilder::slt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
- return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
+ return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
+ return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
}
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
- return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
}
namespace mlir::arith {
@@ -348,8 +348,8 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType) {
- Value one = builder.create<ConstantOp>(loc, resultType,
- builder.getOneAttr(resultType));
+ Value one = ConstantOp::create(builder, loc, resultType,
+ builder.getOneAttr(resultType));
ArithBuilder arithBuilder(builder, loc);
return std::accumulate(
values.begin(), values.end(), one,
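Finally, `ArithBuilder` dispatches on operand type (float ops for `FloatType` operands, integer ops with the configured overflow flags otherwise), and `createProduct` folds a list of values through it, seeded with a constant one of the result type. A usage sketch with hypothetical `builder`, `loc`, and index values `x`, `y`, `z`:

  ArithBuilder ab(builder, loc);
  Value diff = ab.sub(x, y);  // index operands -> arith.subi
  Value prod =
      arith::createProduct(builder, loc, {x, y, z}, builder.getIndexType());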