[Mlir-commits] [mlir] 38976a0 - [mlir][NFC] update `Conversion` create APIs (7/n) (#149889)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 22 07:41:10 PDT 2025
Author: Maksim Levental
Date: 2025-07-22T10:41:06-04:00
New Revision: 38976a03cd367b27437e0d1e81c0ccaee2777b47
URL: https://github.com/llvm/llvm-project/commit/38976a03cd367b27437e0d1e81c0ccaee2777b47
DIFF: https://github.com/llvm/llvm-project/commit/38976a03cd367b27437e0d1e81c0ccaee2777b47.diff
LOG: [mlir][NFC] update `Conversion` create APIs (7/n) (#149889)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
Added:
Modified:
mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 0df91a243d07a..240491a51d2b9 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -340,7 +340,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
Operation *terminator = lastBodyBlock->getTerminator();
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto step = forOp.getStep();
- auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
+ auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
if (!stepped)
return failure();
@@ -348,7 +348,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
loopCarried.push_back(stepped);
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
auto branchOp =
- rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
+ cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
// llvm.loop_annotation attribute.
@@ -375,16 +375,15 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
SmallVector<Value, 8> destOperands;
destOperands.push_back(lowerBound);
llvm::append_range(destOperands, forOp.getInitArgs());
- rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
+ cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
- auto comparison = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, iv, upperBound);
+ auto comparison = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
- rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
- ArrayRef<Value>(), endBlock,
- ArrayRef<Value>());
+ cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
+ ArrayRef<Value>(), endBlock, ArrayRef<Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
@@ -409,7 +408,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
SmallVector<Location>(ifOp.getNumResults(), loc));
- rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
+ cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
}
// Move blocks from the "then" region to the region containing 'scf.if',
@@ -419,7 +418,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
+ cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock);
@@ -433,15 +432,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
+ cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
- /*trueArgs=*/ArrayRef<Value>(), elseBlock,
- /*falseArgs=*/ArrayRef<Value>());
+ cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
+ /*trueArgs=*/ArrayRef<Value>(), elseBlock,
+ /*falseArgs=*/ArrayRef<Value>());
// Ok, we're done!
rewriter.replaceOp(ifOp, continueBlock->getArguments());
@@ -459,13 +458,14 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
auto &region = op.getRegion();
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<cf::BranchOp>(loc, &region.front());
+ cf::BranchOp::create(rewriter, loc, &region.front());
for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block);
- rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
+ cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
+ terminatorOperands);
rewriter.eraseOp(terminator);
}
}
@@ -503,7 +503,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
for (auto [iv, lower, upper, step] :
llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep())) {
- ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
+ ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
ivs.push_back(forOp.getInductionVar());
auto iterRange = forOp.getRegionIterArgs();
iterArgs.assign(iterRange.begin(), iterRange.end());
@@ -517,7 +517,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
// A loop is constructed with an empty "yield" terminator if there are
// no results.
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
- rewriter.create<scf::YieldOp>(loc, forOp.getResults());
+ scf::YieldOp::create(rewriter, loc, forOp.getResults());
}
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -549,7 +549,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
// has been already created in loop construction).
if (!yieldOperands.empty()) {
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
- rewriter.create<scf::YieldOp>(loc, yieldOperands);
+ scf::YieldOp::create(rewriter, loc, yieldOperands);
}
rewriter.replaceOp(parallelOp, loopResults);
@@ -575,7 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
+ cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
// Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last
@@ -625,14 +625,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
+ cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- rewriter.create<cf::CondBranchOp>(condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
+ before, condOp.getArgs(), continuation,
+ ValueRange());
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
@@ -695,12 +695,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
// Cast switch index to integer case value.
- Value caseValue = rewriter.create<arith::IndexCastOp>(
- op.getLoc(), rewriter.getI32Type(), op.getArg());
+ Value caseValue = arith::IndexCastOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg());
- rewriter.create<cf::SwitchOp>(
- op.getLoc(), caseValue, *defaultBlock, ValueRange(),
- rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+ cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
+ ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues),
+ caseSuccessors, caseOperands);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index dcb48529a74e6..84cbd869c78ef 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -91,7 +91,7 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
- rewriter.create<emitc::VariableOp>(loc, varType, noInit);
+ emitc::VariableOp::create(rewriter, loc, varType, noInit);
resultVariables.push_back(var);
}
@@ -103,14 +103,14 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
- rewriter.create<emitc::AssignOp>(loc, var, value);
+ emitc::AssignOp::create(rewriter, loc, var, value);
}
SmallVector<Value> loadValues(const SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
- return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
+ return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
});
}
@@ -129,7 +129,7 @@ static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
assignValues(yieldOperands, resultVariables, rewriter, loc);
- rewriter.create<emitc::YieldOp>(loc);
+ emitc::YieldOp::create(rewriter, loc);
rewriter.eraseOp(yield);
return success();
@@ -164,8 +164,9 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
- emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
- loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
+ emitc::ForOp loweredFor =
+ emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
+ adaptor.getUpperBound(), adaptor.getStep());
Block *loweredBody = loweredFor.getBody();
@@ -257,7 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
bool hasElseBlock = !elseRegion.empty();
auto loweredIf =
- rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
+ emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
Region &loweredThenRegion = loweredIf.getThenRegion();
auto result = lowerRegion(thenRegion, loweredThenRegion);
@@ -304,8 +305,9 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
"create variables for results failed");
}
- auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
- loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
+ auto loweredSwitch =
+ emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
+ adaptor.getCases(), indexSwitchOp.getNumCases());
// Lowering all case regions.
for (auto pair :
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 844e66e927c4d..f191f3502cf5a 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -84,8 +84,8 @@ static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) {
// Get a Value that corresponds to the loop step. If the step is an attribute,
// materialize a corresponding constant using builder.
static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
- return builder.create<arith::ConstantIndexOp>(forOp.getLoc(),
- forOp.getStepAsInt());
+ return arith::ConstantIndexOp::create(builder, forOp.getLoc(),
+ forOp.getStepAsInt());
}
// Get a Value for the loop lower bound. If the value requires computation,
@@ -190,12 +190,12 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) {
return std::nullopt;
}
- Value range = builder.create<arith::SubIOp>(currentLoop.getLoc(),
- upperBound, lowerBound);
+ Value range = arith::SubIOp::create(builder, currentLoop.getLoc(),
+ upperBound, lowerBound);
Value step = getOrCreateStep(currentLoop, builder);
if (getConstantIntValue(step) != static_cast<int64_t>(1))
- range =
- builder.create<arith::CeilDivSIOp>(currentLoop.getLoc(), range, step);
+ range = arith::CeilDivSIOp::create(builder, currentLoop.getLoc(), range,
+ step);
dims.push_back(range);
lbs.push_back(lowerBound);
@@ -221,7 +221,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
// no loop mapped to a specific dimension, use constant "1" as its size.
Value constOne =
(numBlockDims < 3 || numThreadDims < 3)
- ? builder.create<arith::ConstantIndexOp>(rootForOp.getLoc(), 1)
+ ? arith::ConstantIndexOp::create(builder, rootForOp.getLoc(), 1)
: nullptr;
Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne;
Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
@@ -232,9 +232,9 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
// Create a launch op and move the body region of the innermost loop to the
// launch op.
- auto launchOp = builder.create<gpu::LaunchOp>(
- rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX,
- blockSizeY, blockSizeZ);
+ auto launchOp =
+ gpu::LaunchOp::create(builder, rootForOp.getLoc(), gridSizeX, gridSizeY,
+ gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
// Replace the loop terminator (loops contain only a single block) with the
// gpu terminator and move the operations from the loop body block to the gpu
@@ -244,7 +244,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
Location terminatorLoc = terminator.getLoc();
terminator.erase();
builder.setInsertionPointToEnd(innermostForOp.getBody());
- builder.create<gpu::TerminatorOp>(terminatorLoc, TypeRange());
+ gpu::TerminatorOp::create(builder, terminatorLoc, TypeRange());
launchOp.getBody().front().getOperations().splice(
launchOp.getBody().front().begin(),
innermostForOp.getBody()->getOperations());
@@ -263,10 +263,10 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
: getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
Value step = steps[en.index()];
if (getConstantIntValue(step) != static_cast<int64_t>(1))
- id = builder.create<arith::MulIOp>(rootForOp.getLoc(), step, id);
+ id = arith::MulIOp::create(builder, rootForOp.getLoc(), step, id);
Value ivReplacement =
- builder.create<arith::AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
+ arith::AddIOp::create(builder, rootForOp.getLoc(), *lbArgumentIt, id);
en.value().replaceAllUsesWith(ivReplacement);
std::advance(lbArgumentIt, 1);
std::advance(stepArgumentIt, 1);
@@ -319,8 +319,8 @@ static Value deriveStaticUpperBound(Value upperBound,
if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
for (const AffineExpr &result : minOp.getMap().getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) {
- return rewriter.create<arith::ConstantIndexOp>(minOp.getLoc(),
- constExpr.getValue());
+ return arith::ConstantIndexOp::create(rewriter, minOp.getLoc(),
+ constExpr.getValue());
}
}
}
@@ -344,8 +344,8 @@ static Value deriveStaticUpperBound(Value upperBound,
if ((lhs.value() < 0) != (rhs.value() < 0))
return {};
- return rewriter.create<arith::ConstantIndexOp>(
- multiplyOp.getLoc(), lhs.value() * rhs.value());
+ return arith::ConstantIndexOp::create(rewriter, multiplyOp.getLoc(),
+ lhs.value() * rhs.value());
}
}
@@ -422,8 +422,8 @@ static LogicalResult processParallelLoop(
if (launchIndependent(val))
return val;
if (auto constOp = val.getDefiningOp<arith::ConstantOp>())
- return rewriter.create<arith::ConstantOp>(constOp.getLoc(),
- constOp.getValue());
+ return arith::ConstantOp::create(rewriter, constOp.getLoc(),
+ constOp.getValue());
return {};
};
@@ -453,8 +453,8 @@ static LogicalResult processParallelLoop(
1, 2,
rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
rewriter.getAffineSymbolExpr(1));
- newIndex = rewriter.create<AffineApplyOp>(
- loc, annotation.getMap().compose(lowerAndStep),
+ newIndex = AffineApplyOp::create(
+ rewriter, loc, annotation.getMap().compose(lowerAndStep),
ValueRange{operand, ensureLaunchIndependent(step),
ensureLaunchIndependent(lowerBound)});
// If there was also a bound, insert that, too.
@@ -498,8 +498,8 @@ static LogicalResult processParallelLoop(
1, 2,
((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0))
.ceilDiv(rewriter.getAffineSymbolExpr(1))));
- Value launchBound = rewriter.create<AffineApplyOp>(
- loc, annotation.getBound().compose(stepMap),
+ Value launchBound = AffineApplyOp::create(
+ rewriter, loc, annotation.getBound().compose(stepMap),
ValueRange{
ensureLaunchIndependent(
cloningMap.lookupOrDefault(upperBound)),
@@ -517,10 +517,10 @@ static LogicalResult processParallelLoop(
if (!boundIsPrecise) {
// We are using an approximation, create a surrounding conditional.
Value originalBound = std::get<3>(config);
- arith::CmpIOp pred = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, newIndex,
+ arith::CmpIOp pred = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, newIndex,
cloningMap.lookupOrDefault(originalBound));
- scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, pred, false);
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, pred, false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Put a sentinel into the worklist so we know when to pop out of the
// if body again. We use the launchOp here, as that cannot be part of
@@ -530,10 +530,10 @@ static LogicalResult processParallelLoop(
}
} else {
// Create a sequential for loop.
- auto loopOp = rewriter.create<scf::ForOp>(
- loc, cloningMap.lookupOrDefault(lowerBound),
- cloningMap.lookupOrDefault(upperBound),
- cloningMap.lookupOrDefault(step));
+ auto loopOp = scf::ForOp::create(rewriter, loc,
+ cloningMap.lookupOrDefault(lowerBound),
+ cloningMap.lookupOrDefault(upperBound),
+ cloningMap.lookupOrDefault(step));
newIndex = loopOp.getInductionVar();
rewriter.setInsertionPointToStart(loopOp.getBody());
// Put a sentinel into the worklist so we know when to pop out of the loop
@@ -608,12 +608,12 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
// sizes. Those will be refined later as we discover them from mappings.
Location loc = parallelOp.getLoc();
Value constantOne =
- rewriter.create<arith::ConstantIndexOp>(parallelOp.getLoc(), 1);
- gpu::LaunchOp launchOp = rewriter.create<gpu::LaunchOp>(
- parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne,
- constantOne, constantOne);
+ arith::ConstantIndexOp::create(rewriter, parallelOp.getLoc(), 1);
+ gpu::LaunchOp launchOp = gpu::LaunchOp::create(
+ rewriter, parallelOp.getLoc(), constantOne, constantOne, constantOne,
+ constantOne, constantOne, constantOne);
rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
- rewriter.create<gpu::TerminatorOp>(loc);
+ gpu::TerminatorOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&launchOp.getBody().front());
IRMapping cloningMap;
@@ -667,7 +667,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
if (externalValues.size())
return failure();
// Replace by gpu.all_reduce.
- auto gpuRedOp = rewriter.create<gpu::AllReduceOp>(loc, newValue);
+ auto gpuRedOp = gpu::AllReduceOp::create(rewriter, loc, newValue);
cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult());
// Copy region.
rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(),
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 584ac2f11b670..34f372af1e4b5 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -187,8 +187,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperands()[reductionIndex].getType();
- auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
- "__scf_reduction", type);
+ auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
+ "__scf_reduction", type);
symbolTable.insert(decl);
builder.createBlock(&decl.getInitializerRegion(),
@@ -196,8 +196,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
{reduce.getOperands()[reductionIndex].getLoc()});
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
Value init =
- builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
- builder.create<omp::YieldOp>(reduce.getLoc(), init);
+ LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue);
+ omp::YieldOp::create(builder, reduce.getLoc(), init);
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
@@ -227,12 +227,12 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
{reduceOperandLoc, reduceOperandLoc});
Block *atomicBlock = &decl.getAtomicReductionRegion().back();
builder.setInsertionPointToEnd(atomicBlock);
- Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
- atomicBlock->getArgument(1));
- builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
- atomicBlock->getArgument(0), loaded,
- LLVM::AtomicOrdering::monotonic);
- builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
+ Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(),
+ atomicBlock->getArgument(1));
+ LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind,
+ atomicBlock->getArgument(0), loaded,
+ LLVM::AtomicOrdering::monotonic);
+ omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
return decl;
}
@@ -380,8 +380,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Allocate reduction variables. Make sure the we don't overflow the stack
// with local `alloca`s by saving and restoring the stack pointer.
Location loc = parallelOp.getLoc();
- Value one = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
+ Value one =
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
+ rewriter.getI64IntegerAttr(1));
SmallVector<Value> reductionVariables;
reductionVariables.reserve(parallelOp.getNumReductions());
auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
@@ -390,9 +391,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
"cannot create a reduction variable if the type is not an LLVM "
"pointer element");
- Value storage =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
- rewriter.create<LLVM::StoreOp>(loc, init, storage);
+ Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
+ init.getType(), one, 0);
+ LLVM::StoreOp::create(rewriter, loc, init, storage);
reductionVariables.push_back(storage);
}
@@ -411,8 +412,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
assert(redRegion.hasOneBlock() &&
"expect reduction region to have one block");
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
- Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
- rD.getType(), pvtRedVar);
+ Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
+ rD.getType(), pvtRedVar);
// Make a copy of the reduction combiner region in the body
mlir::OpBuilder builder(rewriter.getContext());
builder.setInsertionPoint(reduce);
@@ -427,7 +428,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
assert(yieldOp && yieldOp.getResults().size() == 1 &&
"expect YieldOp in reduction region to return one result");
Value redVal = yieldOp.getResults()[0];
- rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
+ LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
rewriter.eraseOp(yieldOp);
break;
}
@@ -437,12 +438,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
Value numThreadsVar;
if (numThreads > 0) {
- numThreadsVar = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsVar = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
}
// Create the parallel wrapper.
- auto ompParallel = rewriter.create<omp::ParallelOp>(
- loc,
+ auto ompParallel = omp::ParallelOp::create(
+ rewriter, loc,
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
@@ -464,7 +465,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
// Create worksharing loop wrapper.
- auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+ auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
if (!reductionVariables.empty()) {
wsloopOp.setReductionSymsAttr(
ArrayAttr::get(rewriter.getContext(), reductionSyms));
@@ -476,7 +477,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
wsloopOp.setReductionByref(
DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
}
- rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+ omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
// The wrapper's entry block arguments will define the reduction
// variables.
@@ -490,8 +491,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
parallelOp.getLoc()));
// Create loop nest and populate region with contents of scf.parallel.
- auto loopOp = rewriter.create<omp::LoopNestOp>(
- parallelOp.getLoc(), parallelOp.getLowerBound(),
+ auto loopOp = omp::LoopNestOp::create(
+ rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep());
rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
@@ -511,13 +512,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
rewriter.setInsertionPointToStart(&loopOpEntryBlock);
- auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
- TypeRange());
- rewriter.create<omp::YieldOp>(loc, ValueRange());
+ auto scope = memref::AllocaScopeOp::create(
+ rewriter, parallelOp.getLoc(), TypeRange());
+ omp::YieldOp::create(rewriter, loc, ValueRange());
Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
rewriter.mergeBlocks(ops, scopeBlock);
rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
- rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
+ memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange());
}
}
@@ -526,7 +527,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
results.reserve(reductionVariables.size());
for (auto [variable, type] :
llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
- Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
+ Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
results.push_back(res);
}
rewriter.replaceOp(parallelOp, results);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 78d13278fef53..dc92367fc58cd 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
rewriter.setInsertionPoint(newOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- loc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
allocas.push_back(alloc);
rewriter.setInsertionPointAfter(newOp);
- Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+ Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
@@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// a single back edge from the continue to header block, and a single exit
// from header to merge.
auto loc = forOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
OpBuilder::InsertionGuard guard(rewriter);
@@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
- rewriter.create<spirv::BranchOp>(loc, header, args);
+ spirv::BranchOp::create(rewriter, loc, header, args);
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = rewriter.create<spirv::SLessThanOp>(
- loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
+ auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
- rewriter.create<spirv::BranchConditionalOp>(
- loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
+ ArrayRef<Value>(), mergeBlock,
+ ArrayRef<Value>());
// Generate instructions to increment the step of the induction variable and
// branch to the header.
@@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
rewriter.setInsertionPointToEnd(continueBlock);
// Add the step to the induction variable and branch to the header.
- Value updatedIndVar = rewriter.create<spirv::IAddOp>(
- loc, newIndVar.getType(), newIndVar, adaptor.getStep());
- rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+ Value updatedIndVar = spirv::IAddOp::create(
+ rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
+ spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
// Infer the return types from the init operands. Vector type may get
// converted to CooperativeMatrix or to Vector type, to avoid having complex
@@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
// Create `spirv.selection` operation, selection header block and merge
// block.
- auto selectionOp =
- rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto selectionOp = spirv::SelectionOp::create(
+ rewriter, loc, spirv::SelectionControl::None);
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
selectionOp.getBody().end());
- rewriter.create<spirv::MergeOp>(loc);
+ spirv::MergeOp::create(rewriter, loc);
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
@@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &thenRegion = ifOp.getThenRegion();
auto *thenBlock = &thenRegion.front();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(thenRegion, mergeBlock);
auto *elseBlock = mergeBlock;
@@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &elseRegion = ifOp.getElseRegion();
elseBlock = &elseRegion.front();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(elseRegion, mergeBlock);
}
// Create a `spirv.BranchConditional` operation for selection header block.
rewriter.setInsertionPointToEnd(selectionHeaderBlock);
- rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
- thenBlock, ArrayRef<Value>(),
- elseBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
+ thenBlock, ArrayRef<Value>(), elseBlock,
+ ArrayRef<Value>());
replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
returnTypes);
@@ -310,7 +312,7 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
auto loc = terminatorOp.getLoc();
for (unsigned i = 0, e = operands.size(); i < e; i++)
- rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+ spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
if (isa<spirv::LoopOp>(parent)) {
// For loops we also need to update the branch jumping back to the
// header.
@@ -319,8 +321,8 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
SmallVector<Value, 8> args(br.getBlockArguments());
args.append(operands.begin(), operands.end());
rewriter.setInsertionPoint(br);
- rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
- args);
+ spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
+ args);
rewriter.eraseOp(br);
}
}
@@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = whileOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
Region &beforeRegion = whileOp.getBefore();
@@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Jump from the loop entry block to the loop header block.
rewriter.setInsertionPointToEnd(&entryBlock);
- rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
+ spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
auto condLoc = cond.getLoc();
@@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Create local variables before the scf.while op.
rewriter.setInsertionPoint(loopOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- condLoc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
// Load the final result values after the scf.while op.
rewriter.setInsertionPointAfter(loopOp);
- auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+ auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
resultValues[i] = loadResult;
// Store the current iteration's result value.
rewriter.setInsertionPointToEnd(&beforeBlock);
- rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+ spirv::StoreOp::create(rewriter, condLoc, alloc, res);
}
rewriter.setInsertionPointToEnd(&beforeBlock);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index d7ae9f0e94fe8..035f197b1eac2 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -68,7 +68,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
/// Copies the given number of bytes from src to dst pointers.
static void copy(Location loc, Value dst, Value src, Value size,
OpBuilder &builder) {
- builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false);
+ LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false);
}
/// Encodes the binding and descriptor set numbers into a new symbolic name.
@@ -194,8 +194,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
if (!kernelFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
- kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), newKernelFuncName,
+ kernelFunc = LLVM::LLVMFuncOp::create(
+ rewriter, rewriter.getUnknownLoc(), newKernelFuncName,
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
ArrayRef<Type>()));
rewriter.setInsertionPoint(launchOp);
@@ -245,8 +245,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
if (!dstGlobal) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
- dstGlobal = rewriter.create<LLVM::GlobalOp>(
- loc, dstGlobalType,
+ dstGlobal = LLVM::GlobalOp::create(
+ rewriter, loc, dstGlobalType,
/*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
/*alignment=*/0);
rewriter.setInsertionPoint(launchOp);
@@ -255,8 +255,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// Copy the data from src operand pointer to dst global variable. Save
// src, dst and size so that we can copy data back after emulating the
// kernel call.
- Value dst = rewriter.create<LLVM::AddressOfOp>(
- loc, typeConverter->convertType(spirvGlobal.getType()),
+ Value dst = LLVM::AddressOfOp::create(
+ rewriter, loc, typeConverter->convertType(spirvGlobal.getType()),
dstGlobal.getSymName());
copy(loc, dst, src, sizeBytes, rewriter);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1d92b5d5562b5..aae3271371c1f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -94,13 +94,13 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter) {
if (isa<VectorType>(srcType)) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<ShapedType>(srcType),
minusOneIntegerAttribute(srcType, rewriter)));
}
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
+ return LLVM::ConstantOp::create(rewriter, loc, dstType,
+ minusOneIntegerAttribute(srcType, rewriter));
}
/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
@@ -108,14 +108,14 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter, double value) {
if (auto vecType = dyn_cast<VectorType>(srcType)) {
auto floatType = cast<FloatType>(vecType.getElementType());
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(vecType,
rewriter.getFloatAttr(floatType, value)));
}
auto floatType = cast<FloatType>(srcType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, dstType, rewriter.getFloatAttr(floatType, value));
+ return LLVM::ConstantOp::create(rewriter, loc, dstType,
+ rewriter.getFloatAttr(floatType, value));
}
/// Utility function for bitfield ops:
@@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
: getBitWidth(srcType);
if (valueBitWidth < targetBitWidth)
- return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
+ return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
// If the bit widths of `Count` and `Offset` are greater than the bit width
// of the target type, they are truncated. Truncation is safe since `Count`
// and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
// both values can be expressed in 8 bits.
if (valueBitWidth > targetBitWidth)
- return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
+ return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
return value;
}
@@ -151,12 +151,12 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
auto llvmVectorType = typeConverter.convertType(vectorType);
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
- Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
+ Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
for (unsigned i = 0; i < numElements; ++i) {
- auto index = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
- broadcasted = rewriter.create<LLVM::InsertElementOp>(
- loc, llvmVectorType, broadcasted, toBroadcast, index);
+ auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
+ rewriter.getI32IntegerAttr(i));
+ broadcasted = LLVM::InsertElementOp::create(
+ rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
}
return broadcasted;
}
@@ -217,8 +217,8 @@ static Type convertStructTypePacked(spirv::StructType type,
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
unsigned value) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, IntegerType::get(rewriter.getContext(), 32),
+ return LLVM::ConstantOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 32),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
@@ -322,8 +322,9 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
auto llvmIndexType = getTypeConverter()->convertType(indexType);
if (!llvmIndexType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
+ Value zero =
+ LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
+ rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
auto elementType = getTypeConverter()->convertType(
@@ -375,20 +376,20 @@ class BitFieldInsertPattern
// Create a mask with bits set outside [Offset, Offset + Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
Value maskShiftedByCount =
- rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
- Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
- maskShiftedByCount, minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
+ Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
+ maskShiftedByCount, minusOne);
Value maskShiftedByCountAndOffset =
- rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
- Value mask = rewriter.create<LLVM::XOrOp>(
- loc, dstType, maskShiftedByCountAndOffset, minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
+ Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
+ maskShiftedByCountAndOffset, minusOne);
// Extract unchanged bits from the `Base` that are outside of
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
Value baseAndMask =
- rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
+ LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
Value insertShiftedByOffset =
- rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
+ LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
insertShiftedByOffset);
return success();
@@ -470,23 +471,23 @@ class BitFieldSExtractPattern
auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
Value size =
isa<VectorType>(srcType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
- : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
+ : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
// Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
// at Offset + Count - 1 is the most significant bit now.
Value countPlusOffset =
- rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
+ LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
Value amountToShiftLeft =
- rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
- Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
- loc, dstType, op.getBase(), amountToShiftLeft);
+ LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
+ Value baseShiftedLeft = LLVM::ShlOp::create(
+ rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
// Shift the result right, filling the bits with the sign bit.
Value amountToShiftRight =
- rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
+ LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
amountToShiftRight);
return success();
@@ -516,13 +517,13 @@ class BitFieldUExtractPattern
// Create a mask with bits set at [0, Count - 1].
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
Value maskShiftedByCount =
- rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
- Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
- minusOne);
+ LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
+ Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
+ minusOne);
// Shift `Base` by `Offset` and apply the mask on it.
Value shiftedBase =
- rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
+ LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
return success();
}
@@ -694,8 +695,8 @@ class ExecutionModePattern
auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
// Create `llvm.mlir.global` with initializer region containing one block.
- auto global = rewriter.create<LLVM::GlobalOp>(
- UnknownLoc::get(context), structType, /*isConstant=*/true,
+ auto global = LLVM::GlobalOp::create(
+ rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true,
LLVM::Linkage::External, executionModeInfoName, Attribute(),
/*alignment=*/0);
Location loc = global.getLoc();
@@ -704,22 +705,23 @@ class ExecutionModePattern
// Initialize the struct and set the execution mode value.
rewriter.setInsertionPointToStart(block);
- Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
- Value executionMode = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type,
+ Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
+ Value executionMode = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
rewriter.getI32IntegerAttr(
static_cast<uint32_t>(executionModeAttr.getValue())));
- structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
- executionMode, 0);
+ SmallVector<int64_t> position{0};
+ structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
+ executionMode, position);
// Insert extra operands if they exist into execution mode info struct.
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto attr = values.getValue()[i];
- Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
- structValue = rewriter.create<LLVM::InsertValueOp>(
- loc, structValue, entry, ArrayRef<int64_t>({1, i}));
+ Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
+ structValue = LLVM::InsertValueOp::create(
+ rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
}
- rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
+ LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
rewriter.eraseOp(op);
return success();
}
@@ -913,7 +915,7 @@ class InverseSqrtPattern
Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
- Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
+ Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
return success();
}
@@ -973,10 +975,10 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
auto mask =
isa<VectorType>(srcType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, dstType,
SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
- : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+ : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
notOp.getOperand(), mask);
return success();
@@ -1034,8 +1036,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
return func;
OpBuilder b(symbolTable->getRegion(0));
- func = b.create<LLVM::LLVMFuncOp>(
- symbolTable->getLoc(), name,
+ func = LLVM::LLVMFuncOp::create(
+ b, symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setConvergent(convergent);
@@ -1047,7 +1049,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
ValueRange args) {
- auto call = builder.create<LLVM::CallOp>(loc, func, args);
+ auto call = LLVM::CallOp::create(builder, loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
@@ -1078,12 +1080,12 @@ class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
Location loc = controlBarrierOp->getLoc();
- Value execution = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
- Value memory = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
- Value semantics = rewriter.create<LLVM::ConstantOp>(
- loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+ Value execution = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value memory = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+ Value semantics = LLVM::ConstantOp::create(
+ rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
auto call = createSPIRVBuiltinCall(loc, rewriter, func,
{execution, memory, semantics});
@@ -1255,10 +1257,12 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
Location loc = op.getLoc();
- Value scope = rewriter.create<LLVM::ConstantOp>(
- loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
- Value groupOp = rewriter.create<LLVM::ConstantOp>(
- loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+ Value scope = LLVM::ConstantOp::create(
+ rewriter, loc, i32Ty,
+ static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value groupOp = LLVM::ConstantOp::create(
+ rewriter, loc, i32Ty,
+ static_cast<int32_t>(adaptor.getGroupOperation()));
SmallVector<Value> operands{scope, groupOp};
operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
@@ -1368,7 +1372,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
return failure();
Block *headerBlock = loopOp.getHeaderBlock();
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
+ LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
rewriter.eraseBlock(entryBlock);
// Branch from merge block to end block.
@@ -1376,7 +1380,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
Operation *terminator = mergeBlock->getTerminator();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(mergeBlock);
- rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
+ LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
rewriter.replaceOp(loopOp, endBlock->getArguments());
@@ -1434,16 +1438,15 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
Operation *terminator = mergeBlock->getTerminator();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(mergeBlock);
- rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
+ LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
// Link current block to `true` and `false` blocks within the selection.
Block *trueBlock = condBrOp.getTrueBlock();
Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
- condBrOp.getTrueTargetOperands(),
- falseBlock,
- condBrOp.getFalseTargetOperands());
+ LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
+ condBrOp.getTrueTargetOperands(), falseBlock,
+ condBrOp.getFalseTargetOperands());
rewriter.eraseBlock(headerBlock);
rewriter.inlineRegionBefore(op.getBody(), continueBlock);
@@ -1521,8 +1524,8 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
Location loc = tanOp.getLoc();
- Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
- Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
+ Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
+ Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
return success();
}
@@ -1549,13 +1552,13 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
Value multiplied =
- rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
- Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
+ LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
+ Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Value numerator =
- rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+ LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
Value denominator =
- rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+ LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
denominator);
return success();
@@ -1594,8 +1597,8 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
if (!elementType)
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Value allocated =
- rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
- rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
+ LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
}
@@ -1656,7 +1659,7 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Create a new `LLVMFuncOp`
Location loc = funcOp.getLoc();
StringRef name = funcOp.getName();
- auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
+ auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
// Convert SPIR-V Function Control to equivalent LLVM function attribute
MLIRContext *context = funcOp.getContext();
@@ -1710,7 +1713,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
ConversionPatternRewriter &rewriter) const override {
auto newModuleOp =
- rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
+ ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
// Remove the terminator block that was automatically added by builder
@@ -1751,7 +1754,7 @@ class VectorShufflePattern
auto componentsArray = components.getValue();
auto *context = rewriter.getContext();
auto llvmI32Type = IntegerType::get(context, 32);
- Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
+ Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) {
if (!isa<IntegerAttr>(componentsArray[i]))
return op.emitError("unable to support non-constant component");
@@ -1767,16 +1770,17 @@ class VectorShufflePattern
baseVector = vector2;
}
- Value dstIndex = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
- Value index = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI32Type,
+ Value dstIndex = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
+ rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+ Value index = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI32Type,
rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
- auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
- loc, scalarType, baseVector, index);
- targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
- extractOp, dstIndex);
+ auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
+ baseVector, index);
+ targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
+ extractOp, dstIndex);
}
rewriter.replaceOp(op, targetOp);
return success();
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index da9ad3dd67328..245e60b04ec31 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -32,7 +32,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override {
- rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
+ cf::AssertOp::create(rewriter, op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index bbe1490137bf8..7025c5a7daf93 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -82,40 +82,40 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
ValueRange rankDiffs, Value outputDimension) {
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Value broadcastedDim = one;
for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
Value shape = std::get<0>(tup);
Value rankDiff = std::get<1>(tup);
- Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
- outputDimension, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
+ outputDimension, rankDiff);
Type indexTy = lb.getIndexType();
broadcastedDim =
- lb.create<IfOp>(
- outOfBounds,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, broadcastedDim);
- },
- [&](OpBuilder &b, Location loc) {
- // The broadcasting logic is:
- // - if one extent (here we arbitrarily choose the
- // extent from the greater-rank operand) is equal to 1,
- // then take the extent from the other operand
- // - otherwise, take the extent as-is.
- // Note that this logic remains correct in the presence
- // of dimensions of zero extent.
- Value lesserRankOperandDimension = b.create<arith::SubIOp>(
- loc, indexTy, outputDimension, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{lesserRankOperandDimension});
-
- Value dimIsOne =
- b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- lesserRankOperandExtent, one);
- Value dim = b.create<arith::SelectOp>(
- loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
- b.create<scf::YieldOp>(loc, dim);
- })
+ IfOp::create(
+ lb, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ scf::YieldOp::create(b, loc, broadcastedDim);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // The broadcasting logic is:
+ // - if one extent (here we arbitrarily choose the
+ // extent from the greater-rank operand) is equal to 1,
+ // then take the extent from the other operand
+ // - otherwise, take the extent as-is.
+ // Note that this logic remains correct in the presence
+ // of dimensions of zero extent.
+ Value lesserRankOperandDimension = arith::SubIOp::create(
+ b, loc, indexTy, outputDimension, rankDiff);
+ Value lesserRankOperandExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{lesserRankOperandDimension});
+
+ Value dimIsOne =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ lesserRankOperandExtent, one);
+ Value dim = arith::SelectOp::create(
+ b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
+ scf::YieldOp::create(b, loc, dim);
+ })
.getResult(0);
}
return broadcastedDim;
@@ -133,7 +133,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -141,31 +141,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
  // Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
- Value replacement = lb.create<tensor::GenerateOp>(
- getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ Value replacement = tensor::GenerateOp::create(
+ lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value broadcastedDim =
getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
rankDiffs, args[0]);
- b.create<tensor::YieldOp>(loc, broadcastedDim);
+ tensor::YieldOp::create(b, loc, broadcastedDim);
});
if (replacement.getType() != op.getType())
- replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+ replacement = tensor::CastOp::create(lb, op.getType(), replacement);
rewriter.replaceOp(op, replacement);
return success();
}
@@ -193,13 +193,13 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
auto loc = op.getLoc();
SmallVector<Value, 4> extentOperands;
for (auto extent : op.getShape()) {
- extentOperands.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
+ extentOperands.push_back(arith::ConstantIndexOp::create(
+ rewriter, loc, extent.getLimitedValue()));
}
Type resultTy =
RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
Value tensor =
- rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
+ tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -245,8 +245,8 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -254,26 +254,26 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
  // Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
Type i1Ty = rewriter.getI1Type();
- Value trueVal =
- rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+ Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
+ rewriter.getBoolAttr(true));
- auto reduceResult = lb.create<ForOp>(
- loc, zero, maxRank, one, ValueRange{trueVal},
+ auto reduceResult = ForOp::create(
+ lb, loc, zero, maxRank, one, ValueRange{trueVal},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
// Find a non-1 dim, if it exists. Note that the first part of this
// could reuse the Broadcast lowering entirely, but we redo the work
@@ -285,38 +285,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
Value shape, rankDiff;
std::tie(shape, rankDiff) = tup;
- Value outOfBounds = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, iv, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
broadcastable =
- b.create<IfOp>(
- loc, outOfBounds,
- [&](OpBuilder &b, Location loc) {
- // Non existent dimensions are always broadcastable
- b.create<scf::YieldOp>(loc, broadcastable);
- },
- [&](OpBuilder &b, Location loc) {
- // Every value needs to be either 1, or the same non-1
- // value to be broadcastable in this dim.
- Value operandDimension =
- b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
- Value dimensionExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{operandDimension});
-
- Value equalOne = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent, one);
- Value equalBroadcasted = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent,
- broadcastedDim);
- Value result = b.create<arith::AndIOp>(
- loc, broadcastable,
- b.create<arith::OrIOp>(loc, equalOne,
- equalBroadcasted));
- b.create<scf::YieldOp>(loc, result);
- })
+ IfOp::create(
+ b, loc, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ // Non existent dimensions are always broadcastable
+ scf::YieldOp::create(b, loc, broadcastable);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Every value needs to be either 1, or the same non-1
+ // value to be broadcastable in this dim.
+ Value operandDimension =
+ arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
+ Value dimensionExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{operandDimension});
+
+ Value equalOne = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
+ Value equalBroadcasted =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ dimensionExtent, broadcastedDim);
+ Value result = arith::AndIOp::create(
+ b, loc, broadcastable,
+ arith::OrIOp::create(b, loc, equalOne,
+ equalBroadcasted));
+ scf::YieldOp::create(b, loc, result);
+ })
.getResult(0);
}
- b.create<scf::YieldOp>(loc, broadcastable);
+ scf::YieldOp::create(b, loc, broadcastable);
});
rewriter.replaceOp(op, reduceResult.getResults().front());
@@ -339,7 +339,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
// Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
// lowerings. This can be further optimized if needed to avoid intermediate
// steps.
- auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
+ auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
op.getIndex());
return success();
@@ -421,16 +421,17 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
Type indexTy = rewriter.getIndexType();
Value rank =
- rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);
- auto loop = rewriter.create<scf::ForOp>(
- loc, zero, rank, one, op.getInitVals(),
+ auto loop = scf::ForOp::create(
+ rewriter, loc, zero, rank, one, op.getInitVals(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
+ Value extent =
+ tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);
SmallVector<Value, 2> mappedValues{iv, extent};
mappedValues.append(args.begin(), args.end());
@@ -444,7 +445,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
SmallVector<Value, 2> mappedResults;
for (auto result : reduceBody->getTerminator()->getOperands())
mappedResults.push_back(mapping.lookup(result));
- b.create<scf::YieldOp>(loc, mappedResults);
+ scf::YieldOp::create(b, loc, mappedResults);
});
rewriter.replaceOp(op, loop.getResults());
@@ -507,44 +508,44 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
Type indexTy = rewriter.getIndexType();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value firstShape = adaptor.getShapes().front();
Value firstRank =
- rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
Value result = nullptr;
// Generate a linear sequence of compares, all with firstShape as lhs.
for (Value shape : adaptor.getShapes().drop_front(1)) {
- Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
- Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- firstRank, rank);
- auto same = rewriter.create<IfOp>(
- loc, eqRank,
+ Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
+ Value eqRank = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
+ auto same = IfOp::create(
+ rewriter, loc, eqRank,
[&](OpBuilder &b, Location loc) {
- Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ Value one = arith::ConstantIndexOp::create(b, loc, 1);
Value init =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
- auto loop = b.create<scf::ForOp>(
- loc, zero, firstRank, one, ValueRange{init},
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
+ auto loop = scf::ForOp::create(
+ b, loc, zero, firstRank, one, ValueRange{init},
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value conj = args[0];
Value lhsExtent =
- b.create<tensor::ExtractOp>(loc, firstShape, iv);
- Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
- Value eqExtent = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
- Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
- b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+ tensor::ExtractOp::create(b, loc, firstShape, iv);
+ Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
+ Value eqExtent = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
+ Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
+ scf::YieldOp::create(b, loc, ValueRange({conjNext}));
});
- b.create<scf::YieldOp>(loc, loop.getResults());
+ scf::YieldOp::create(b, loc, loop.getResults());
},
[&](OpBuilder &b, Location loc) {
Value result =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
- b.create<scf::YieldOp>(loc, result);
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
+ scf::YieldOp::create(b, loc, result);
});
result = !result ? same.getResult(0)
- : rewriter.create<arith::AndIOp>(loc, result,
- same.getResult(0));
+ : arith::AndIOp::create(rewriter, loc, result,
+ same.getResult(0));
}
rewriter.replaceOp(op, result);
return success();
@@ -581,18 +582,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
- Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
+ Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
extentValues.push_back(extent);
} else {
- Value extent = rewriter.create<arith::ConstantIndexOp>(
- loc, rankedTensorTy.getDimSize(i));
+ Value extent = arith::ConstantIndexOp::create(
+ rewriter, loc, rankedTensorTy.getDimSize(i));
extentValues.push_back(extent);
}
}
// Materialize extent tensor.
- Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
+ Value staticExtentTensor = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
extentValues);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
@@ -601,13 +602,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Lower to `tensor.generate` otherwise.
auto *ctx = rewriter.getContext();
- Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
+ Value rank = tensor::RankOp::create(rewriter, loc, tensor);
rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
- Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
- b.create<tensor::YieldOp>(loc, extent);
+ Value extent = tensor::DimOp::create(b, loc, tensor, dim);
+ tensor::YieldOp::create(b, loc, extent);
});
return success();
@@ -634,22 +635,22 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value zero = b.create<arith::ConstantIndexOp>(0);
- Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
+ Value zero = arith::ConstantIndexOp::create(b, 0);
+ Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);
// index < 0 ? index + rank : index
Value originalIndex = adaptor.getIndex();
- Value add = b.create<arith::AddIOp>(originalIndex, rank);
+ Value add = arith::AddIOp::create(b, originalIndex, rank);
Value indexIsNegative =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
- Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
+ Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);
- Value one = b.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(b, 1);
Value head =
- b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
- Value tailSize = b.create<arith::SubIOp>(rank, index);
- Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
- tailSize, one);
+ tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
+ Value tailSize = arith::SubIOp::create(b, rank, index);
+ Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
+ tailSize, one);
rewriter.replaceOp(op, {head, tail});
return success();
}
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 2c4d27502a521..f24972f6b6ee1 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -68,10 +68,10 @@ class TensorExtractPattern final
// We could use the initializer directly; but certain driver compilers
// have bugs dealing with that. So for now, use spirv.Store for
// initialization.
- varOp = rewriter.create<spirv::VariableOp>(loc, varType,
- spirv::StorageClass::Function,
- /*initializer=*/nullptr);
- rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
+ varOp = spirv::VariableOp::create(rewriter, loc, varType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
+ spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor());
} else {
// Need to store the value to the local variable. It's questionable
// whether we want to support such case though.
@@ -83,7 +83,7 @@ class TensorExtractPattern final
Value index = spirv::linearizeIndex(adaptor.getIndices(), strides,
/*offset=*/0, indexType, loc, rewriter);
- auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
+ auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 40ad63610e23f..044b725c7d805 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
Value getConstantValue(Location loc, Type type, int64_t value,
PatternRewriter &rewriter) {
- return rewriter.create<arith::ConstantOp>(
- loc, getConstantAttr(type, value, rewriter));
+ return arith::ConstantOp::create(rewriter, loc,
+ getConstantAttr(type, value, rewriter));
}
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
@@ -82,41 +82,41 @@ class ApplyScaleGenericOpConverter
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = value;
if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
- value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+ value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
Value multiplier64 =
- rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+ arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
Value multiply64 =
- rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
+ arith::MulIOp::create(rewriter, loc, value64, multiplier64);
// Apply normal rounding.
- Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
- Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
- round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
- multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
+ Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
+ Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
+ round = arith::ShRUIOp::create(rewriter, loc, round, one64);
+ multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
- Value positive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value, zero);
+ Value positive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value, zero);
Value dir =
- rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
- Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
- Value valid = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
+ arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
+ Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
+ Value valid = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
multiply64 =
- rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
+ arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
}
- Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
- Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
+ Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
+ Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
rewriter.replaceOp(op, result32);
return success();
@@ -146,7 +146,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
Value value32 = op.getValue();
Value multiplier32 = op.getMultiplier();
- Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
+ Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
// Constants used during the scaling operation.
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
@@ -158,86 +158,87 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
// Compute the multiplication in 64-bits then select the high / low parts.
// Grab out the high/low of the computation
auto value64 =
- rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
+ arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
Value low32 = value64.getLow();
Value high32 = value64.getHigh();
// Determine the direction and amount to shift the high bits.
- Value shiftOver32 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
- Value roundHighBits = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
+ Value shiftOver32 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
+ Value roundHighBits = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
Value shiftHighL =
- rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
+ arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
Value shiftHighR =
- rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
+ arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
shiftHighL =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
shiftHighR =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
- Value valuePositive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, value32, zero32);
+ Value valuePositive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
- Value roundDir =
- rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
+ Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
+ one32, negOne32);
roundDir =
- rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
+ arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
- Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
- Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
- Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
+ Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
+ Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
+ Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
Value shiftRound =
- rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
+ arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
- low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
+ low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
}
// Conditionally apply rounding in the low bits.
{
- Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
- roundBit);
-
- Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
- Value wasRounded = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ugt, low32, newLow32);
+ Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
+ roundBit);
+
+ Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
+ Value wasRounded = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
low32 = newLow32;
- Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
+ Value rounded32 =
+ arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
}
// Conditionally apply rounding in the high bits.
{
Value shiftSubOne =
- rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
- Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
- roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
- zero32);
- high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
+ arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
+ Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
+ roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
+ zero32);
+ high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
}
// Combine the correct high/low bits into the final rescale result.
- high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
- high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
- low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
- low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
+ high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
+ high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
+ low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
+ low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
// Apply the rounding behavior and shift to the final alignment.
- Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
+ Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
// Truncate if necessary.
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
- result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
+ result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2f608bbd637b4..ec55091cd7eb8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -70,14 +70,14 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
// Unordered comparison of NaN against itself will always return true.
- Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
- op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
- Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
- op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+ Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
+ arith::CmpFPredicate::UNO, lhs, lhs);
+ Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
+ arith::CmpFPredicate::UNO, rhs, rhs);
Value rhsOrResult =
- rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
- return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
- rhsOrResult);
+ arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
+ return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
+ rhsOrResult);
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -89,38 +89,38 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::AbsOp
if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
+ return math::AbsFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(elementTy));
- auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
- return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
+ auto zero = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getZeroAttr(elementTy));
+ auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
}
// tosa::AddOp
if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
+ return arith::AddFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
+ return arith::AddIOp::create(rewriter, loc, resultTypes, args);
// tosa::SubOp
if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
+ return arith::SubFOp::create(rewriter, loc, resultTypes, args);
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
+ return arith::SubIOp::create(rewriter, loc, resultTypes, args);
// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
+ return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
// tosa::ReciprocalOp
if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
auto one =
- rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
- return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
+ arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
+ return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
}
// tosa::MulOp
@@ -140,7 +140,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
"Cannot have shift value for float");
return nullptr;
}
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
+ return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
+ args[1]);
}
if (isa<IntegerType>(elementTy)) {
@@ -149,21 +150,21 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (shift > 0) {
auto shiftConst =
- rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+ arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
if (!a.getType().isInteger(32))
- a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+ a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
- b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+ b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = rewriter.create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), a, b, shiftConst,
+ auto result = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getStringAttr("SINGLE_ROUND"));
if (elementTy.isInteger(32))
return result;
- return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ return arith::TruncIOp::create(rewriter, loc, elementTy, result);
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -171,11 +172,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
if (aWidth < cWidth)
- a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+ a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
if (bWidth < cWidth)
- b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+ b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
- return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
}
}
@@ -201,14 +202,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int64_t outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+ return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
if (isa<IntegerType>(elementTy)) {
if (!inZp && !outZp) {
- auto constant = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(elementTy, 0));
- return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
- args[0]);
+ auto constant = arith::ConstantOp::create(
+ rewriter, loc, IntegerAttr::get(elementTy, 0));
+ return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
+ args[0]);
}
// Compute the maximum value that can occur in the intermediate buffer.
@@ -231,214 +232,214 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ Value zpAddValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
// The negation can be applied by doing:
// outputValue = inZp + outZp - inputValue
auto ext =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
- auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
+ auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
// Clamp to the negation range.
- Value min = rewriter.create<arith::ConstantIntOp>(
- loc, intermediateType,
+ Value min = arith::ConstantIntOp::create(
+ rewriter, loc, intermediateType,
APInt::getSignedMinValue(inputBitWidth).getSExtValue());
- Value max = rewriter.create<arith::ConstantIntOp>(
- loc, intermediateType,
+ Value max = arith::ConstantIntOp::create(
+ rewriter, loc, intermediateType,
APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
// Truncate to the final value.
- return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
}
}
// tosa::BitwiseAndOp
if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
+ return arith::AndIOp::create(rewriter, loc, resultTypes, args);
// tosa::BitwiseOrOp
if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
+ return arith::OrIOp::create(rewriter, loc, resultTypes, args);
// tosa::BitwiseNotOp
if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
auto allOnesAttr = rewriter.getIntegerAttr(
elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
- auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
+ auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
}
// tosa::BitwiseXOrOp
if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalLeftShiftOp
if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
+ return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalRightShiftOp
if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
- return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
+ return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
// tosa::ArithmeticRightShiftOp
if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
- auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
+ auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
if (!round) {
return result;
}
Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
- auto one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
- auto zero =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(elementTy, 1));
+ auto zero = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(elementTy, 0));
auto i1one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
+ arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
// Checking that input2 != 0
- auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[1], zero);
+ auto shiftValueGreaterThanZero = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
// Checking for the last bit of input1 to be 1
auto subtract =
- rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
+ arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
auto shifted =
- rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
+ arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
->getResults();
- auto truncated = rewriter.create<arith::TruncIOp>(
- loc, i1Ty, shifted, ArrayRef<NamedAttribute>());
+ auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
+ ArrayRef<NamedAttribute>());
auto isInputOdd =
- rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
+ arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
- auto shouldRound = rewriter.create<arith::AndIOp>(
- loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ auto shouldRound = arith::AndIOp::create(
+ rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
auto extended =
- rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
- return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
+ arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
+ return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
}
// tosa::ClzOp
if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
+ return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
}
// tosa::LogicalAnd
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
+ return arith::AndIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalNot
if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(elementTy, 1));
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(elementTy, 1));
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
}
// tosa::LogicalOr
if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
+ return arith::OrIOp::create(rewriter, loc, resultTypes, args);
// tosa::LogicalXor
if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
+ return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// tosa::PowOp
if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
+ return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
// tosa::RsqrtOp
if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
+ return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
// tosa::LogOp
if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
+ return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
// tosa::ExpOp
if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
+ return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
// tosa::SinOp
if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
+ return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
// tosa::CosOp
if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
+ return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
// tosa::TanhOp
if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
+ return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
// tosa::ErfOp
if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
- return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+ return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
// tosa::GreaterOp
if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
+ args[0], args[1]);
if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
+ args[0], args[1]);
// tosa::GreaterEqualOp
if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
+ args[0], args[1]);
if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
+ args[0], args[1]);
// tosa::EqualOp
if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
- args[0], args[1]);
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
+ args[0], args[1]);
if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- args[0], args[1]);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ args[0], args[1]);
// tosa::SelectOp
if (isa<tosa::SelectOp>(op)) {
elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
- return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
+ return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
}
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
- auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
rewriter, args[0], args[1], max);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
- return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
}
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
- auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
rewriter, args[0], args[1], min);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
- return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
+ return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
}
// tosa::CeilOp
if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::CeilOp>(loc, resultTypes, args);
+ return math::CeilOp::create(rewriter, loc, resultTypes, args);
// tosa::FloorOp
if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<math::FloorOp>(loc, resultTypes, args);
+ return math::FloorOp::create(rewriter, loc, resultTypes, args);
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
@@ -449,10 +450,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
APFloat::rmNearestTiesToEven, &losesInfo);
maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
- auto min = rewriter.create<arith::ConstantOp>(
- loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
- auto max = rewriter.create<arith::ConstantOp>(
- loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
+ auto min = arith::ConstantOp::create(
+ rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
+ auto max = arith::ConstantOp::create(
+ rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
auto clampOp = llvm::cast<tosa::ClampOp>(op);
@@ -478,11 +479,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// return init if x == NaN else result
// Unordered comparison of NaN against itself will always return true.
- Value isNaN = rewriter.create<arith::CmpFOp>(
- op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+ Value isNaN = arith::CmpFOp::create(
+ rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
// TOSA specifies that in "ignore" NaN mode the result is "min" if the input
// is NaN.
- return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
+ return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
}
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -515,10 +516,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
min = std::min(min, maxRepresentable);
max = std::min(max, maxRepresentable);
- auto minVal = rewriter.create<arith::ConstantIntOp>(
- loc, min, intTy.getIntOrFloatBitWidth());
- auto maxVal = rewriter.create<arith::ConstantIntOp>(
- loc, max, intTy.getIntOrFloatBitWidth());
+ auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
+ intTy.getIntOrFloatBitWidth());
+ auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
+ intTy.getIntOrFloatBitWidth());
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
intTy.isUnsignedInteger());
}
@@ -526,11 +527,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::SigmoidOp
if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
auto one =
- rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
- auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
- auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
- auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
- return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
+ arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
+ auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
+ auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
+ auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
+ return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
}
// tosa::CastOp
@@ -549,21 +550,21 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return args.front();
if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
- return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// 1-bit integers need to be treated as signless.
if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
- return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// Unsigned integers need an unrealized cast so that they can be passed
// to UIToFP.
@@ -574,25 +575,25 @@ static Value createLinalgBodyCalculationForElementwiseOp(
loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
args[0])
.getResult(0);
- return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
- unrealizedCast);
+ return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
+ unrealizedCast);
}
// All other si-to-fp conversions should be handled by SIToFP.
if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
- return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
// Casting to boolean, floats need to only be checked as not-equal to zero.
if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(srcTy, 0.0));
- return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
- args.front(), zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(srcTy, 0.0));
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
+ args.front(), zero);
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+ auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
// Check whether neither int min nor int max can be represented in the
@@ -601,37 +602,42 @@ static Value createLinalgBodyCalculationForElementwiseOp(
APFloat::semanticsMaxExponent(fltSemantics)) {
// Use cmp + select to replace infinites by int min / int max. Other
// integral values can be represented in the integer space.
- auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
- auto posInf = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
- APFloat::getInf(fltSemantics)));
- auto negInf = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APFloat::getInf(fltSemantics, /*Negative=*/true)));
- auto overflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UEQ, rounded, posInf);
- auto underflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UEQ, rounded, negInf);
- auto intMin = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
- auto intMax = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
+ auto posInf = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics)));
+ auto negInf = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APFloat::getInf(fltSemantics, /*Negative=*/true)));
+ auto overflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
+ auto underflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
+ auto intMin = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMax = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
auto maxClamped =
- rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
- return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
- maxClamped);
+ arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
+ return arith::SelectOp::create(rewriter, loc, underflow, intMin,
+ maxClamped);
}
- auto intMinFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
+ auto intMinFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
// Check whether the mantissa has enough bits to represent int max.
if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
@@ -640,58 +646,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// consists of a single leading bit. Therefore we can clamp the input
// in the floating-point domain.
- auto intMaxFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
+ auto intMaxFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
Value clamped =
clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
}
// Due to earlier check we know exponant range is big enough to represent
// int min. We can therefore rely on int max + 1 being representable as
// well because it's just int min with a positive sign. So clamp the min
// value and compare against that to select the max int value if needed.
- auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(
- getElementTypeOrSelf(srcTy),
- static_cast<double>(
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()) +
- 1.0f));
-
- auto intMax = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(
- getElementTypeOrSelf(dstTy),
- APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto intMaxPlusOneFP = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ static_cast<double>(
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()) +
+ 1.0f));
+
+ auto intMax = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
auto minClampedFP =
- rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
auto minClamped =
- rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
- auto overflow = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
- return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
- minClamped);
+ arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
+ auto overflow = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return arith::SelectOp::create(rewriter, loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
// zero.
if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
- Value zero = rewriter.create<arith::ConstantIntOp>(
- loc, 0, srcTy.getIntOrFloatBitWidth());
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
- args.front(), zero);
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
+ srcTy.getIntOrFloatBitWidth());
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
+ args.front(), zero);
}
if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
- return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
- return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
+ return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
}
}
@@ -710,14 +719,14 @@ static Value createIndex(PatternRewriter &rewriter, Location loc,
auto [it, inserted] = indexPool.try_emplace(index);
if (inserted)
it->second =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
return it->second;
}
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value tensor, int64_t index) {
auto indexValue = createIndex(rewriter, loc, indexPool, index);
- return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
+ return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
@@ -783,7 +792,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
auto nextSize =
getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
- targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
+ targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
}
return {targetSize, nullptr};
}
@@ -838,8 +847,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Check if broadcast is necessary
auto one = createIndex(rewriter, loc, indexPool, 1);
auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
- auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, runtimeSize, one);
+ auto broadcastNecessary = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
@@ -855,8 +864,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
operand, index);
outputTensorShape.push_back(size);
}
- Value outputTensor = opBuilder.create<tensor::EmptyOp>(
- loc, outputTensorShape, rankedTensorType.getElementType());
+ Value outputTensor = tensor::EmptyOp::create(
+ opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
// Emit 'linalg.generic' op
auto resultTensor =
@@ -866,7 +875,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
// Emit 'linalg.yield' op
- opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
+ linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
})
.getResult(0);
@@ -875,17 +884,17 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
loc, operand.getType(), resultTensor);
// Emit 'scf.yield' op
- opBuilder.create<scf::YieldOp>(loc, castResultTensor);
+ scf::YieldOp::create(opBuilder, loc, castResultTensor);
};
// Emit 'else' region of 'scf.if'
auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
- opBuilder.create<scf::YieldOp>(loc, operand);
+ scf::YieldOp::create(opBuilder, loc, operand);
};
// Emit 'scf.if' op
- auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
- emitThenRegion, emitElseRegion);
+ auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
+ emitThenRegion, emitElseRegion);
return ifOp.getResult(0);
}
@@ -930,8 +939,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
if (!resultType) {
return rewriter.notifyMatchFailure(operation, "failed to convert type");
}
- Value outputTensor = rewriter.create<tensor::EmptyOp>(
- loc, targetShape, resultType.getElementType());
+ Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
+ resultType.getElementType());
// Create affine maps. Input affine maps broadcast static dimensions of size
// 1. The output affine map is an identity map.
@@ -957,8 +966,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
// Emit 'linalg.generic' op
bool encounteredError = false;
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, outputTensor.getType(), operands, outputTensor, affineMaps,
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
Value opResult = createLinalgBodyCalculationForElementwiseOp(
@@ -968,7 +977,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
encounteredError = true;
return;
}
- opBuilder.create<linalg::YieldOp>(loc, opResult);
+ linalg::YieldOp::create(opBuilder, loc, opResult);
});
if (encounteredError)
return rewriter.notifyMatchFailure(
@@ -1078,42 +1087,42 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::AddFOp>(loc, args);
+ return arith::AddFOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::AddIOp>(loc, args);
+ return arith::AddIOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MulFOp>(loc, args);
+ return arith::MulFOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MulIOp>(loc, args);
+ return arith::MulIOp::create(rewriter, loc, args);
}
if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
+ return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
- return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
+ return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::AndIOp>(loc, args);
+ return arith::AndIOp::create(rewriter, loc, args);
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
- return rewriter.create<arith::OrIOp>(loc, args);
+ return arith::OrIOp::create(rewriter, loc, args);
return {};
}
@@ -1139,7 +1148,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if (axis != i) {
reduceShape.push_back(inputTy.getDimSize(i));
if (inputTy.isDynamicDim(i))
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -1158,7 +1167,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
return rewriter.notifyMatchFailure(
op, "No initial value found for reduction operation");
- auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+ auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
@@ -1176,7 +1185,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Additionally we have to keep track of whether we've seen any non-NaN
// values and then do a final select based on this predicate.
auto trueAttr = rewriter.getBoolAttr(true);
- auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+ auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
auto emptyBoolTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
@@ -1202,8 +1211,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
}
bool didEncounterError = false;
- linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
- loc, inputs, outputs, axis,
+ linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
+ rewriter, loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
@@ -1219,21 +1228,22 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto oldAllResultsNanFlagValue = blockArgs[3];
// Unordered comparison of NaN against itself will always return true.
- Value isNaN = nestedBuilder.create<arith::CmpFOp>(
- op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+ Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
+ arith::CmpFPredicate::UNO,
+ inputValue, inputValue);
// If we've encountered a NaN, take the non-NaN value.
- auto selectOp = nestedBuilder.create<arith::SelectOp>(
- op->getLoc(), isNaN, initialValue, result);
+ auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
+ isNaN, initialValue, result);
// Update the flag which keeps track of whether we have seen a non-NaN
// value.
- auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
- op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+ auto newAllResultsNanFlagValue = arith::AndIOp::create(
+ nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
resultsToYield.push_back(selectOp);
resultsToYield.push_back(newAllResultsNanFlagValue);
} else {
resultsToYield.push_back(result);
}
- nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
+ linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
});
if (!didEncounterError)
@@ -1250,7 +1260,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto nanValueAttr = rewriter.getFloatAttr(
elementTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
- auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+ auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
rewriter
.create<tensor::EmptyOp>(loc, reduceShape,
@@ -1278,7 +1288,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
ins.push_back(linalgOp->getResult(0));
outs.push_back(finalEmptyTensor);
auto linalgSelect =
- rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+ linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
linalgOp = linalgSelect;
}
@@ -1350,7 +1360,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < outputTy.getRank(); i++) {
if (outputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -1401,16 +1411,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value multiplierConstant;
int64_t multiplierArg = 0;
if (multiplierValues.size() == 1) {
- multiplierConstant = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+ multiplierConstant = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
- genericInputs.push_back(rewriter.create<arith::ConstantOp>(
- loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+ genericInputs.push_back(arith::ConstantOp::create(
+ rewriter, loc,
+ DenseIntElementsAttr::get(multiplierType, multiplierValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
@@ -1424,16 +1435,16 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value shiftConstant;
int64_t shiftArg = 0;
if (shiftValues.size() == 1) {
- shiftConstant = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+ shiftConstant = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
SmallVector<AffineExpr, 2> shiftExprs = {
rewriter.getAffineDimExpr(rank - 1)};
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
- genericInputs.push_back(rewriter.create<arith::ConstantOp>(
- loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+ genericInputs.push_back(arith::ConstantOp::create(
+ rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
@@ -1444,13 +1455,13 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// Construct the indexing maps needed for linalg.generic ops.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(),
+ Value emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
ArrayRef<Value>({dynDims}));
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
- getNParallelLoopsAttrs(rank),
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
+ indexingMaps, getNParallelLoopsAttrs(rank),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value value = blockArgs[0];
@@ -1466,9 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = nestedBuilder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = arith::ConstantOp::create(
+ nestedBuilder, loc,
+ IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+ *maybeIZp));
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
@@ -1482,9 +1494,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned outBitWidth = outIntType.getWidth();
const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = nestedBuilder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
+ auto outputZp = arith::ConstantOp::create(
+ nestedBuilder, loc,
+ IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+ *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1501,24 +1514,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
if (valueTy.getIntOrFloatBitWidth() < 32) {
if (op.getInputUnsigned()) {
- value = nestedBuilder.create<arith::ExtUIOp>(
- nestedLoc, nestedBuilder.getI32Type(), value);
+ value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
+ nestedBuilder.getI32Type(), value);
} else {
- value = nestedBuilder.create<arith::ExtSIOp>(
- nestedLoc, nestedBuilder.getI32Type(), value);
+ value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
+ nestedBuilder.getI32Type(), value);
}
}
value =
- nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
+ arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
- value = nestedBuilder.create<tosa::ApplyScaleOp>(
- loc, nestedBuilder.getI32Type(), value, multiplier, shift,
- roundingMode);
+ value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
+ nestedBuilder.getI32Type(), value,
+ multiplier, shift, roundingMode);
// Move to the new zero-point.
value =
- nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
+ arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
// Saturate to the output size.
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
@@ -1530,18 +1543,18 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
}
- auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
- loc, nestedBuilder.getI32IntegerAttr(intMin));
- auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
- loc, nestedBuilder.getI32IntegerAttr(intMax));
+ auto intMinVal = arith::ConstantOp::create(
+ nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
+ auto intMaxVal = arith::ConstantOp::create(
+ nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
nestedBuilder, /*isUnsigned=*/false);
if (outIntType.getWidth() < 32) {
- value = nestedBuilder.create<arith::TruncIOp>(
- nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
- value);
+ value = arith::TruncIOp::create(
+ nestedBuilder, nestedLoc,
+ rewriter.getIntegerType(outIntType.getWidth()), value);
}
if (outIntType.isUnsignedInteger()) {
@@ -1550,7 +1563,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
outIntType, value)
.getResult(0);
}
- nestedBuilder.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(nestedBuilder, loc, value);
});
rewriter.replaceOp(op, linalgOp->getResults());
@@ -1608,48 +1621,49 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto collapseTy =
RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
inputTy.getElementType());
- Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
- reassociationMap);
+ Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
+ reassociationMap);
// Get any dynamic shapes that appear in the input format.
llvm::SmallVector<Value> outputDynSize;
if (inputTy.isDynamicDim(0))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
if (inputTy.isDynamicDim(3))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
// Generate the elementwise operation for casting scaling the input value.
auto genericTy = collapseTy.clone(resultTy.getElementType());
- Value empty = builder.create<tensor::EmptyOp>(
- genericTy.getShape(), resultTy.getElementType(), outputDynSize);
+ Value empty =
+ tensor::EmptyOp::create(builder, genericTy.getShape(),
+ resultTy.getElementType(), outputDynSize);
auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
utils::IteratorType::parallel);
- auto generic = builder.create<linalg::GenericOp>(
- genericTy, ValueRange{collapse}, ValueRange{empty},
+ auto generic = linalg::GenericOp::create(
+ builder, genericTy, ValueRange{collapse}, ValueRange{empty},
ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value value = args[0];
// This is the quantized case.
if (inputTy.getElementType() != resultTy.getElementType()) {
- value =
- b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
+ value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
+ value);
if (isBilinear && scale[0] != 0) {
- Value scaleY = b.create<arith::ConstantOp>(
- loc, b.getI32IntegerAttr(scale[0]));
- value = b.create<arith::MulIOp>(loc, value, scaleY);
+ Value scaleY = arith::ConstantOp::create(
+ b, loc, b.getI32IntegerAttr(scale[0]));
+ value = arith::MulIOp::create(b, loc, value, scaleY);
}
if (isBilinear && scale[2] != 0) {
- Value scaleX = b.create<arith::ConstantOp>(
- loc, b.getI32IntegerAttr(scale[2]));
- value = b.create<arith::MulIOp>(loc, value, scaleX);
+ Value scaleX = arith::ConstantOp::create(
+ b, loc, b.getI32IntegerAttr(scale[2]));
+ value = arith::MulIOp::create(b, loc, value, scaleX);
}
}
- b.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(b, loc, value);
});
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -1697,9 +1711,9 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
resizeShape.push_back(channels);
auto resizeTy = resultTy.clone(resizeShape);
- auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
- op.getOffset(), op.getBorder(),
- op.getMode());
+ auto resize =
+ tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
+ op.getOffset(), op.getBorder(), op.getMode());
// Collapse an unit result dims.
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1720,20 +1734,20 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
collapseShape.push_back(channels);
auto collapseTy = resultTy.clone(collapseShape);
- Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
- reassociationMap);
+ Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
+ resize, reassociationMap);
// Broadcast the collapsed shape to the output result.
llvm::SmallVector<Value> outputDynSize;
if (inputTy.isDynamicDim(0))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
if (inputTy.isDynamicDim(3))
- outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+ outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
utils::IteratorType::parallel);
- Value empty = builder.create<tensor::EmptyOp>(
- resultTy.getShape(), resultTy.getElementType(), outputDynSize);
+ Value empty = tensor::EmptyOp::create(
+ builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
if (inputH != 1)
@@ -1751,7 +1765,7 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value value = args[0];
- b.create<linalg::YieldOp>(loc, value);
+ linalg::YieldOp::create(b, loc, value);
});
return success();
@@ -1789,10 +1803,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
- *dynamicDimsOr);
- auto genericOp = b.create<linalg::GenericOp>(
- resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
+ auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
+ resultETy, *dynamicDimsOr);
+ auto genericOp = linalg::GenericOp::create(
+ b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()));
Value resize = genericOp.getResult(0);
@@ -1800,19 +1814,21 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
OpBuilder::InsertionGuard regionGuard(b);
b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
TypeRange({resultETy}), loc);
- Value batch = b.create<linalg::IndexOp>(0);
- Value y = b.create<linalg::IndexOp>(1);
- Value x = b.create<linalg::IndexOp>(2);
- Value channel = b.create<linalg::IndexOp>(3);
+ Value batch = linalg::IndexOp::create(b, 0);
+ Value y = linalg::IndexOp::create(b, 1);
+ Value x = linalg::IndexOp::create(b, 2);
+ Value channel = linalg::IndexOp::create(b, 3);
Value zeroI32 =
- b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
- Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
- Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
- Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
+ arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
+ Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
+ Value hMax =
+ arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
+ Value wMax =
+ arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
- Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
- Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
+ Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
+ Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
SmallVector<int64_t> scale, offset, border;
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
@@ -1824,16 +1840,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
Value yScaleN, yScaleD, xScaleN, xScaleD;
- yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
- yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
- xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
- xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
+ yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
+ yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
+ xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
+ xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
Value yOffset, xOffset, yBorder, xBorder;
- yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
- xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
- yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
- xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
+ yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
+ xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
+ yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
+ xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
// Compute the ix and dx values for both the X and Y dimensions.
auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
@@ -1846,16 +1862,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
- Value val = b.create<arith::MulIOp>(in, scaleD);
- val = b.create<arith::AddIOp>(val, offset);
- index = b.create<arith::FloorDivSIOp>(val, scaleN);
+ Value val = arith::MulIOp::create(b, in, scaleD);
+ val = arith::AddIOp::create(b, val, offset);
+ index = arith::FloorDivSIOp::create(b, val, scaleN);
// rx = x % scale_n
// dx = rx / scale_n
- Value r = b.create<arith::RemSIOp>(val, scaleN);
- Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
- Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
- delta = b.create<arith::DivFOp>(rFp, scaleNfp);
+ Value r = arith::RemSIOp::create(b, val, scaleN);
+ Value rFp = arith::SIToFPOp::create(b, floatTy, r);
+ Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
+ delta = arith::DivFOp::create(b, rFp, scaleNfp);
};
// Compute the ix and dx values for the X and Y dimensions - int case.
@@ -1870,11 +1886,11 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
// dx = x - ix * scale_n;
- Value val = b.create<arith::MulIOp>(in, scaleD);
- val = b.create<arith::AddIOp>(val, offset);
- index = b.create<arith::DivSIOp>(val, scaleN);
- delta = b.create<arith::MulIOp>(index, scaleN);
- delta = b.create<arith::SubIOp>(val, delta);
+ Value val = arith::MulIOp::create(b, in, scaleD);
+ val = arith::AddIOp::create(b, val, offset);
+ index = arith::DivSIOp::create(b, val, scaleN);
+ delta = arith::MulIOp::create(b, index, scaleN);
+ delta = arith::SubIOp::create(b, val, delta);
};
Value ix, iy, dx, dy;
@@ -1887,54 +1903,55 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
}
if (op.getMode() == "NEAREST_NEIGHBOR") {
- auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
+ auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
Value max, int size,
ImplicitLocOpBuilder &b) -> Value {
if (size == 1) {
- return b.create<arith::ConstantIndexOp>(0);
+ return arith::ConstantIndexOp::create(b, 0);
}
Value pred;
if (floatingPointMode) {
- auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
- pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
+ auto h =
+ arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
+ pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
} else {
- Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
- pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
- dvalDouble, scale);
+ Value dvalDouble = arith::ShLIOp::create(b, dval, one);
+ pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
+ dvalDouble, scale);
}
- auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
- val = b.create<arith::AddIOp>(val, offset);
+ auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
+ val = arith::AddIOp::create(b, val, offset);
val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
- return b.create<arith::IndexCastOp>(b.getIndexType(), val);
+ return arith::IndexCastOp::create(b, b.getIndexType(), val);
};
iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
- Value result = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, iy, ix, channel});
+ Value result = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, iy, ix, channel});
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
assert(op.getMode() == "BILINEAR");
- auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
+ auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
Value max, ImplicitLocOpBuilder &b) {
val0 = in;
- val1 = b.create<arith::AddIOp>(val0, oneVal);
+ val1 = arith::AddIOp::create(b, val0, oneVal);
val0 =
clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
val1 =
clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
- val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
- val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
+ val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
+ val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
};
// Linalg equivalent to the section below:
@@ -1946,27 +1963,27 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getClampedIdxs(y0, y1, imageH, iy, hMax, b);
getClampedIdxs(x0, x1, imageW, ix, wMax, b);
- Value y0x0 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y0, x0, channel});
- Value y0x1 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y0, x1, channel});
- Value y1x0 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y1, x0, channel});
- Value y1x1 = b.create<tensor::ExtractOp>(
- input, ValueRange{batch, y1, x1, channel});
+ Value y0x0 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y0, x0, channel});
+ Value y0x1 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y0, x1, channel});
+ Value y1x0 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y1, x0, channel});
+ Value y1x1 = tensor::ExtractOp::create(
+ b, input, ValueRange{batch, y1, x1, channel});
if (floatingPointMode) {
auto oneVal =
- b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
+ arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
auto interpolate = [&](Value val0, Value val1, Value delta,
int inputSize,
ImplicitLocOpBuilder &b) -> Value {
if (inputSize == 1)
return val0;
- Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
- Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
- Value mul1 = b.create<arith::MulFOp>(val1, delta);
- return b.create<arith::AddFOp>(mul0, mul1);
+ Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
+ Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
+ Value mul1 = arith::MulFOp::create(b, val1, delta);
+ return arith::AddFOp::create(b, mul0, mul1);
};
// Linalg equivalent to the section below:
@@ -1982,18 +1999,18 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// Linalg equivalent to the section below:
// result = topAcc * (unit_y - dy) + bottomAcc * dy
Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
} else {
// Perform in quantized space.
- y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
- y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
- y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
- y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
+ y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
+ y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
+ y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
+ y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
- dx = b.create<arith::ExtSIOp>(resultETy, dx);
- dy = b.create<arith::ExtSIOp>(resultETy, dy);
+ dx = arith::ExtSIOp::create(b, resultETy, dx);
+ dy = arith::ExtSIOp::create(b, resultETy, dy);
}
Value yScaleNExt = yScaleN;
@@ -2002,26 +2019,26 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
const int64_t scaleBitwidth =
xScaleN.getType().getIntOrFloatBitWidth();
if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
- yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
- xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
+ yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
+ xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
}
auto interpolate = [](Value val0, Value val1, Value weight1,
Value scale, int inputSize,
ImplicitLocOpBuilder &b) -> Value {
if (inputSize == 1)
- return b.create<arith::MulIOp>(val0, scale);
- Value weight0 = b.create<arith::SubIOp>(scale, weight1);
- Value mul0 = b.create<arith::MulIOp>(val0, weight0);
- Value mul1 = b.create<arith::MulIOp>(val1, weight1);
- return b.create<arith::AddIOp>(mul0, mul1);
+ return arith::MulIOp::create(b, val0, scale);
+ Value weight0 = arith::SubIOp::create(b, scale, weight1);
+ Value mul0 = arith::MulIOp::create(b, val0, weight0);
+ Value mul1 = arith::MulIOp::create(b, val1, weight1);
+ return arith::AddIOp::create(b, mul0, mul1);
};
Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
Value result =
interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
- b.create<linalg::YieldOp>(result);
+ linalg::YieldOp::create(b, result);
}
}
}
@@ -2072,11 +2089,11 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
- Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
+ Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
// First fill the output buffer with the init value.
auto emptyTensor = rewriter
@@ -2094,22 +2111,22 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
llvm::SmallVector<Value> indices;
for (unsigned int i = 0; i < inputTy.getRank(); i++) {
Value index =
- rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+ linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
if (i == axis) {
- auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
+ auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
auto sizeMinusOne =
- rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
- index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
- index);
+ arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
+ index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
+ index);
}
indices.push_back(index);
}
- auto extract = nestedBuilder.create<tensor::ExtractOp>(
- nestedLoc, input, indices);
- nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
- extract.getResult());
+ auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
+ input, indices);
+ linalg::YieldOp::create(nestedBuilder, op.getLoc(),
+ extract.getResult());
});
return success();
}
@@ -2148,12 +2165,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- op.getLoc(), genericShape, elementTy, dynDims);
+ auto emptyTensor = tensor::EmptyOp::create(
+ rewriter, op.getLoc(), genericShape, elementTy, dynDims);
// We needs to map the input shape to the non-broadcasted dimensions.
SmallVector<AffineExpr, 4> dimExprs;
@@ -2168,12 +2185,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
SmallVector<AffineMap, 2> affineMaps = {
readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, RankedTensorType::get(genericShape, elementTy), input,
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(genericShape.size()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+ linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
});
auto shapeValue = getTosaConstShape(
@@ -2220,7 +2237,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) && i != axis) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
}
}
@@ -2229,8 +2246,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
.create<tensor::EmptyOp>(loc, resultTy.getShape(),
outElementTy, dynDims)
.getResult();
- auto fillValueIdx = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(outElementTy, 0));
+ auto fillValueIdx = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
auto filledTensorIdx =
rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
@@ -2250,7 +2267,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
argmaxOp, "unsupported tosa.argmax element type");
auto fillValueMax =
- rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
+ arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
auto filledTensorMax =
rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValueMax},
@@ -2274,8 +2291,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
rewriter.getContext());
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
+ auto linalgOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
@@ -2283,42 +2300,46 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
auto oldIndex = blockArgs[1];
auto oldValue = blockArgs[2];
- Value newIndex = rewriter.create<arith::IndexCastOp>(
- nestedLoc, oldIndex.getType(),
- rewriter.create<linalg::IndexOp>(loc, axis));
+ Value newIndex = arith::IndexCastOp::create(
+ rewriter, nestedLoc, oldIndex.getType(),
+ linalg::IndexOp::create(rewriter, loc, axis));
Value predicate;
if (isa<FloatType>(inElementTy)) {
if (argmaxOp.getNanMode() == "IGNORE") {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
- predicate = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+ predicate = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::OGT,
+ newValue, oldValue);
} else {
// Update max value if either of the following is true:
// - new value is bigger
// - cur max is not NaN and new value is NaN
- Value gt = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
- Value oldNonNaN = rewriter.create<arith::CmpFOp>(
- nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
- predicate = rewriter.create<arith::AndIOp>(
- nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
+ Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::UGT,
+ newValue, oldValue);
+ Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
+ arith::CmpFPredicate::ORD,
+ oldValue, oldValue);
+ predicate = arith::AndIOp::create(
+ rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
}
} else if (isa<IntegerType>(inElementTy)) {
- predicate = rewriter.create<arith::CmpIOp>(
- nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
+ predicate = arith::CmpIOp::create(rewriter, nestedLoc,
+ arith::CmpIPredicate::sgt,
+ newValue, oldValue);
} else {
didEncounterError = true;
return;
}
- auto resultMax = rewriter.create<arith::SelectOp>(
- nestedLoc, predicate, newValue, oldValue);
- auto resultIndex = rewriter.create<arith::SelectOp>(
- nestedLoc, predicate, newIndex, oldIndex);
- nestedBuilder.create<linalg::YieldOp>(
- nestedLoc, ValueRange({resultIndex, resultMax}));
+ auto resultMax = arith::SelectOp::create(
+ rewriter, nestedLoc, predicate, newValue, oldValue);
+ auto resultIndex = arith::SelectOp::create(
+ rewriter, nestedLoc, predicate, newIndex, oldIndex);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc,
+ ValueRange({resultIndex, resultMax}));
});
if (didEncounterError)
@@ -2363,19 +2384,19 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &b, Location loc, ValueRange args) {
auto indexValue = args[0];
- auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
- Value index1 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), indexValue);
- auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
- Value extract = rewriter.create<tensor::ExtractOp>(
- loc, input, ValueRange{index0, index1, index2});
- rewriter.create<linalg::YieldOp>(loc, extract);
+ auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
+ Value index1 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), indexValue);
+ auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
+ Value extract = tensor::ExtractOp::create(
+ rewriter, loc, input, ValueRange{index0, index1, index2});
+ linalg::YieldOp::create(rewriter, loc, extract);
});
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
@@ -2424,7 +2445,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
for (int i = 0; i < resultTy.getRank(); ++i) {
if (inputTy.isDynamicDim(i)) {
dynDims.push_back(
- rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+ tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
}
}
@@ -2437,9 +2458,9 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
rewriter.getMultiDimIdentityMap(resultTy.getRank()),
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
- getNParallelLoopsAttrs(resultTy.getRank()));
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
+ affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
rewriter.replaceOp(op, genericOp.getResult(0));
{
@@ -2452,69 +2473,69 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
rewriter.setInsertionPointToStart(block);
if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
resultElementTy.isInteger(8)) {
- Value index = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), inputValue);
- Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
- index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
- index, offset);
+ Value index = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), inputValue);
+ Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
+ index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
+ index, offset);
Value extract =
- rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
- rewriter.create<linalg::YieldOp>(loc, extract);
+ tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
+ linalg::YieldOp::create(rewriter, loc, extract);
return success();
}
if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
resultElementTy.isInteger(32)) {
- Value extend = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), inputValue);
-
- auto offset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(32768));
- auto seven = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(7));
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1));
- auto b1111111 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(127));
+ Value extend = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), inputValue);
+
+ auto offset = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(32768));
+ auto seven = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(7));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(1));
+ auto b1111111 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(127));
// Compute the index and fractional part from the input value:
// value = value + 32768
// index = value >> 7;
// fraction = 0x01111111 & value
- auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
- Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
+ auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
+ Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
Value fraction =
- rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
+ arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
// Extract the base and next values from the table.
// base = (int32_t) table[index];
// next = (int32_t) table[index + 1];
- Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
+ Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
- index = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), index);
- indexPlusOne = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), indexPlusOne);
+ index = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), index);
+ indexPlusOne = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), indexPlusOne);
Value base =
- rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
- Value next = rewriter.create<tensor::ExtractOp>(
- loc, table, ValueRange{indexPlusOne});
+ tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
+ Value next = tensor::ExtractOp::create(rewriter, loc, table,
+ ValueRange{indexPlusOne});
base =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
+ arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
next =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
+ arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
// Use the fractional part to interpolate between the input values:
// result = (base << 7) + (next - base) * fraction
- Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
- Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
- Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
+ Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
+ Value diff = arith::SubIOp::create(rewriter, loc, next, base);
+ Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
Value result =
- rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
+ arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
- rewriter.create<linalg::YieldOp>(loc, result);
+ linalg::YieldOp::create(rewriter, loc, result);
return success();
}
@@ -2532,8 +2553,8 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
OpFoldResult ofr) {
- auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
- auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
+ auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ auto two = arith::ConstantIndexOp::create(builder, loc, 2);
auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
@@ -2562,9 +2583,9 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
RankedTensorType type,
llvm::ArrayRef<Value> dynamicSizes) {
auto emptyTensor =
- rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
+ tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
- auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+ auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
@@ -2574,18 +2595,18 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
static Value castIndexToFloat(OpBuilder &builder, Location loc,
FloatType type, Value value) {
- auto integerVal = builder.create<arith::IndexCastUIOp>(
- loc,
+ auto integerVal = arith::IndexCastUIOp::create(
+ builder, loc,
type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
: builder.getI32Type(),
value);
- return builder.create<arith::UIToFPOp>(loc, type, integerVal);
+ return arith::UIToFPOp::create(builder, loc, type, integerVal);
}
static Value createLinalgIndex(OpBuilder &builder, Location loc,
FloatType type, int64_t index) {
- auto indexVal = builder.create<linalg::IndexOp>(loc, index);
+ auto indexVal = linalg::IndexOp::create(builder, loc, index);
return castIndexToFloat(builder, loc, type, indexVal);
}
@@ -2640,7 +2661,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
// Constants and dimension sizes
auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
- auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+ auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
@@ -2650,43 +2671,45 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
Value sumImag = args[2];
// Indices for angle computation
- Value oy = builder.create<linalg::IndexOp>(loc, 1);
- Value ox = builder.create<linalg::IndexOp>(loc, 2);
- Value iy = builder.create<linalg::IndexOp>(loc, 3);
- Value ix = builder.create<linalg::IndexOp>(loc, 4);
+ Value oy = linalg::IndexOp::create(builder, loc, 1);
+ Value ox = linalg::IndexOp::create(builder, loc, 2);
+ Value iy = linalg::IndexOp::create(builder, loc, 3);
+ Value ix = linalg::IndexOp::create(builder, loc, 4);
// Calculating angle without integer parts of components as sin/cos are
// periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
// / W);
- auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
- auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+ auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
+ auto ixXox = index::MulOp::create(builder, loc, ix, ox);
- auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
- auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+ auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
+ auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
- auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
- auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+ auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
+ auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
+ auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
+ auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
// realComponent = valReal * cos(angle)
// imagComponent = valReal * sin(angle)
- auto cosAngle = builder.create<math::CosOp>(loc, angle);
- auto sinAngle = builder.create<math::SinOp>(loc, angle);
+ auto cosAngle = math::CosOp::create(builder, loc, angle);
+ auto sinAngle = math::SinOp::create(builder, loc, angle);
auto realComponent =
- builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+ arith::MulFOp::create(builder, loc, valReal, cosAngle);
auto imagComponent =
- builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+ arith::MulFOp::create(builder, loc, valReal, sinAngle);
// outReal = sumReal + realComponent
// outImag = sumImag - imagComponent
- auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
- auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
+ auto outReal =
+ arith::AddFOp::create(builder, loc, sumReal, realComponent);
+ auto outImag =
+ arith::SubFOp::create(builder, loc, sumImag, imagComponent);
- builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+ linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
};
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
@@ -2760,7 +2783,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
// Constants and dimension sizes
auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
- auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+ auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
Value constH =
RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
Value constW =
@@ -2773,57 +2796,59 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
Value sumImag = args[3];
// Indices for angle computation
- Value oy = builder.create<linalg::IndexOp>(loc, 1);
- Value ox = builder.create<linalg::IndexOp>(loc, 2);
- Value iy = builder.create<linalg::IndexOp>(loc, 3);
- Value ix = builder.create<linalg::IndexOp>(loc, 4);
+ Value oy = linalg::IndexOp::create(builder, loc, 1);
+ Value ox = linalg::IndexOp::create(builder, loc, 2);
+ Value iy = linalg::IndexOp::create(builder, loc, 3);
+ Value ix = linalg::IndexOp::create(builder, loc, 4);
// float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
// ox) % W ) / W);
- auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
- auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+ auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
+ auto ixXox = index::MulOp::create(builder, loc, ix, ox);
- auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
- auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+ auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
+ auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
auto iyRemFloat =
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
auto ixRemFloat =
RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
+ auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
+ auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
- auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
- auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+ auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
+ auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
if (inverse.getValue()) {
- angle = builder.create<arith::MulFOp>(
- loc, angle,
- rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
+ angle = arith::MulFOp::create(
+ builder, loc, angle,
+ arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(real_el_ty, -1.0)));
}
// realComponent = val_real * cos(a) + val_imag * sin(a);
// imagComponent = -val_real * sin(a) + val_imag * cos(a);
- auto cosAngle = builder.create<math::CosOp>(loc, angle);
- auto sinAngle = builder.create<math::SinOp>(loc, angle);
+ auto cosAngle = math::CosOp::create(builder, loc, angle);
+ auto sinAngle = math::SinOp::create(builder, loc, angle);
- auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
- auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
- auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
+ auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
+ auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
+ auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
- auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
- auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+ auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
+ auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
- auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
+ auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
// outReal = sumReal + realComponent
// outImag = sumImag - imagComponent
- auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
- auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
+ auto outReal =
+ arith::AddFOp::create(builder, loc, sumReal, realComponent);
+ auto outImag =
+ arith::AddFOp::create(builder, loc, sumImag, imagComponent);
- builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+ linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
};
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 00b9a065dfb3d..3a205246ddd9e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -52,11 +52,11 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
highIndices.push_back(rewriter.getIndexAttr(highPad));
}
- Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
+ Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);
- return rewriter.create<tensor::PadOp>(
- loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
- highIndices, padValue);
+ return tensor::PadOp::create(rewriter, loc,
+ RankedTensorType::get(paddedShape, inputETy),
+ input, lowIndices, highIndices, padValue);
}
static mlir::Value
@@ -72,10 +72,10 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value biasVal = args[0];
Type resType = args[1].getType();
if (resType != biasVal.getType()) {
- biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+ biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
}
- Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
- builder.create<linalg::YieldOp>(loc, added);
+ Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
})
.getResult(0);
}
@@ -134,19 +134,19 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
if (resType != biasVal.getType()) {
biasVal =
resultTy.getElementType().isFloat()
- ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
.getResult()
- : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+ : arith::ExtSIOp::create(builder, loc, resType, biasVal)
.getResult();
}
- builder.create<linalg::YieldOp>(loc, biasVal);
+ linalg::YieldOp::create(builder, loc, biasVal);
})
.getResult(0);
}
static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
- return builder.create<arith::ConstantIndexOp>(attr);
+ return arith::ConstantIndexOp::create(builder, attr);
}
// Calculating the output width/height using the formula:
@@ -160,22 +160,22 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
int64_t dilationAttr,
OpBuilder &rewriter) {
ImplicitLocOpBuilder builder(loc, rewriter);
- auto one = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(inputDim.getType(), 1));
+ auto one = arith::ConstantOp::create(rewriter, loc,
+ IntegerAttr::get(inputDim.getType(), 1));
Value padBefore = reifyConstantDim(padBeforeAttr, builder);
- Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
+ Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
Value padAfter = reifyConstantDim(padAfterAttr, builder);
- Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
+ Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
- Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
+ Value subOne = arith::SubIOp::create(builder, kernelDim, one);
Value dilation = reifyConstantDim(dilationAttr, builder);
- Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
- Value addOne = builder.create<arith::AddIOp>(dilated, one);
+ Value dilated = arith::MulIOp::create(builder, dilation, subOne);
+ Value addOne = arith::AddIOp::create(builder, dilated, one);
- Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
+ Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
Value stride = reifyConstantDim(strideAttr, builder);
- Value divide = builder.create<arith::DivUIOp>(subtract, stride);
- return builder.create<arith::AddIOp>(divide, one);
+ Value divide = arith::DivUIOp::create(builder, subtract, stride);
+ return arith::AddIOp::create(builder, divide, one);
}
// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
@@ -198,9 +198,9 @@ static SmallVector<Value> inferDynamicDimsForConv(
auto padBottom = padAttr[i * 2 + 1];
auto stride = strideAttr[i];
auto dilation = dilationAttr[i];
- Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
+ Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
Value kernelDynDim =
- rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
+ tensor::DimOp::create(rewriter, loc, weight, kernelDim);
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
dynDims[inputDim] =
getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
@@ -211,7 +211,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
// Get the batch/channels dimensions.
for (int i = 0; i < inputRank; i++) {
if (resultTy.isDynamicDim(i) && !dynDims[i])
- dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+ dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -350,8 +350,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
- weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermAttr);
+ weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
+ weightPermAttr);
}
}
@@ -372,8 +372,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
- weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermAttr);
+ weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
+ weightPermAttr);
}
// Extract the attributes for convolution.
@@ -384,8 +384,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto strideAttr = rewriter.getI64TensorAttr(stride);
auto dilationAttr = rewriter.getI64TensorAttr(dilation);
- Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), accETy, filteredDims);
+ Value biasEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), accETy, filteredDims);
Value broadcastBias =
linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
@@ -394,8 +394,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
- auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
- auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+ auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
+ auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
Value conv =
rewriter
@@ -417,7 +417,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (resultTy != accTy)
- conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+ conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);
rewriter.replaceOp(op, conv);
return success();
@@ -526,16 +526,16 @@ class DepthwiseConvConverter
accETy);
auto resultZeroAttr = rewriter.getZeroAttr(accETy);
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, linalgConvTy.getShape(), accETy, filteredDims);
- Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+ Value emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
+ Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, filteredDims);
+ Value biasEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
// Broadcast the initial value to the output tensor before convolving.
SmallVector<AffineMap, 4> indexingMaps;
@@ -553,16 +553,16 @@ class DepthwiseConvConverter
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (accETy != resultETy)
- conv = rewriter.create<tosa::CastOp>(
- loc,
+ conv = tosa::CastOp::create(
+ rewriter, loc,
RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
resultETy),
conv);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
- Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ Value convReshape = tensor::CollapseShapeOp::create(
+ rewriter, loc, resultTy, conv, reassociationMap);
Value result =
rewriter
@@ -574,20 +574,20 @@ class DepthwiseConvConverter
ValueRange args) {
Value added;
if (llvm::isa<FloatType>(inputETy))
- added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
- args[1]);
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
else
- added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
- args[1]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
})
.getResult(0);
rewriter.replaceOp(op, result);
} else {
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
- auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
- auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
+ auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
+ auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
Value conv =
rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
@@ -596,8 +596,8 @@ class DepthwiseConvConverter
.getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
- Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ Value convReshape = tensor::CollapseShapeOp::create(
+ rewriter, loc, resultTy, conv, reassociationMap);
Value result = linalgIntBroadcastExtSIAdd(
rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
rewriter.replaceOp(op, result);
@@ -621,23 +621,24 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
- dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
+ dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
}
if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
- dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
+ dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
}
if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
- dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
+ dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
- Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
+ Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
+ outputTy.getElementType(), filteredDims);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
@@ -670,10 +671,10 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
return success();
}
- auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(aZpVal));
- auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(bZpVal));
+ auto aZp = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(aZpVal));
+ auto bZp = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(bZpVal));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -702,7 +703,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// Batch dimension
if (resultTy.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
// Height/width dimensions
for (int64_t dim : {1, 2}) {
@@ -713,10 +714,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
int64_t index = dim - 1;
// Input height/width
- Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+ Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
// Kernel height/width
- Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+ Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
// Output height/width
Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
@@ -727,7 +728,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// Channel dimension
if (resultTy.isDynamicDim(3))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
return dynamicDims;
}
@@ -776,7 +777,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
- Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
+ Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
@@ -785,15 +786,16 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
// Create the linalg op that performs pooling.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
+ Value emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultTy.getElementType(), dynamicDims);
Value filledEmptyTensor =
- rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
+ linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
.result();
Value fakeWindowDims =
- rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
+ tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
if (isUnsigned) {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
@@ -802,8 +804,8 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
return llvm::success();
}
- auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
- op->getLoc(), ArrayRef<Type>{resultTy},
+ auto resultOp = linalg::PoolingNhwcMaxOp::create(
+ rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
dilationAttr);
@@ -823,9 +825,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
if (nanMode == "IGNORE") {
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(),
- resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(),
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
+ resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+ resultOp.getIteratorTypesArray(),
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
IRMapping map;
auto oldBlock = resultOp.getRegion().begin();
@@ -833,12 +836,12 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
auto &oldMaxOp = *resultOp.getBlock()->begin();
map.map(oldArgs, blockArgs);
auto *newOp = opBuilder.clone(oldMaxOp, map);
- Value isNaN = opBuilder.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, blockArgs.front(),
- blockArgs.front());
- auto selectOp = opBuilder.create<arith::SelectOp>(
- loc, isNaN, blockArgs.back(), newOp->getResult(0));
- opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
+ Value isNaN =
+ arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
+ blockArgs.front(), blockArgs.front());
+ auto selectOp = arith::SelectOp::create(
+ opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
+ linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
});
rewriter.replaceOp(resultOp, genericOp);
}
@@ -894,7 +897,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
auto initialAttr = rewriter.getZeroAttr(accETy);
- Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
+ Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
@@ -903,8 +906,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
// Create the linalg op that performs pooling.
- Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, accTy.getShape(), accETy, dynamicDims);
+ Value poolEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
rewriter
@@ -913,7 +916,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
.result();
Value fakeWindowDims =
- rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);
+ tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
Value poolingOp = rewriter
@@ -925,24 +928,24 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Normalize the summed value by the number of elements grouped in each
// pool.
- Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
- Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+ Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
+ Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- iH = rewriter.create<arith::SubIOp>(loc, iH, one);
- iW = rewriter.create<arith::SubIOp>(loc, iW, one);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ iH = arith::SubIOp::create(rewriter, loc, iH, one);
+ iW = arith::SubIOp::create(rewriter, loc, iW, one);
- Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, dynamicDims);
+ Value genericEmptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
ValueRange{genericEmptyTensor},
ArrayRef<AffineMap>({affineMap, affineMap}),
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &b, Location loc, ValueRange args) {
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
// Determines what the portion of valid input is covered by the
// kernel.
@@ -950,30 +953,30 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
if (pad == 0)
return valid;
- auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
- Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
+ auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
+ Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
- Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
- return rewriter.create<arith::AddIOp>(loc, valid, offset)
+ Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
+ return arith::AddIOp::create(rewriter, loc, valid, offset)
->getResult(0);
};
auto coverageFn = [&](int64_t i, Value isize) -> Value {
Value strideVal =
- rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+ arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
Value val =
- rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+ arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);
// Find the position relative to the input tensor's ends.
- Value left = rewriter.create<linalg::IndexOp>(loc, i);
- Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
- left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
- right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+ Value left = linalg::IndexOp::create(rewriter, loc, i);
+ Value right = arith::SubIOp::create(rewriter, loc, isize, left);
+ left = arith::MulIOp::create(rewriter, loc, left, strideVal);
+ right = arith::MulIOp::create(rewriter, loc, right, strideVal);
// Determine how much padding was included.
val = padFn(val, left, pad[i * 2]);
val = padFn(val, right, pad[i * 2 + 1]);
- return rewriter.create<arith::MaxSIOp>(loc, one, val);
+ return arith::MaxSIOp::create(rewriter, loc, one, val);
};
// Compute the indices from either end.
@@ -981,70 +984,70 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
Value kW3 = coverageFn(2, iW);
// Compute the total number of elements and normalize.
- auto count = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(),
- rewriter.create<arith::MulIOp>(loc, kH3, kW3));
+ auto count = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ arith::MulIOp::create(rewriter, loc, kH3, kW3));
// Divide by the number of summed values. For floats this is just
// a div however for quantized values input normalization had
// to be applied.
Value poolVal = args[0];
if (isa<FloatType>(accETy)) {
- auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
- poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
+ auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
+ poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
->getResult(0);
if (accETy.getIntOrFloatBitWidth() >
resultETy.getIntOrFloatBitWidth())
poolVal =
- rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
+ arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
} else {
// If we have quantization information we need to apply an offset
// for the input zp value.
if (inputZpVal != 0) {
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(accETy, inputZpVal));
+ auto inputZp = arith::ConstantOp::create(
+ rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
- rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
+ arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
poolVal =
- rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
+ arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
}
// Compute: k = 32 - count_leading_zeros(value - 1)
- Value one32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1));
- Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(32));
+ Value one32 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(1));
+ Value thirtyTwo32 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(32));
Value countSubOne =
- rewriter.create<arith::SubIOp>(loc, count, one32);
+ arith::SubIOp::create(rewriter, loc, count, one32);
Value leadingZeros =
- rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+ math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
Value k =
- rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+ arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);
// Compute: numerator = ((1 << 30) + 1) << k
Value k64 =
- rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
- Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+ arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
+ Value thirtyShiftPlusOne = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
Value numerator =
- rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+ arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
// Compute: scale.multiplier = numerator / value;
- Value count64 = rewriter.create<arith::ExtUIOp>(
- loc, rewriter.getI64Type(), count);
+ Value count64 = arith::ExtUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), count);
Value multiplier =
- rewriter.create<arith::DivUIOp>(loc, numerator, count64);
- multiplier = rewriter.create<arith::TruncIOp>(
- loc, rewriter.getI32Type(), multiplier);
+ arith::DivUIOp::create(rewriter, loc, numerator, count64);
+ multiplier = arith::TruncIOp::create(
+ rewriter, loc, rewriter.getI32Type(), multiplier);
// Compute: scale.shift = 30 + k
Value k8 =
- rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
- Value thirty8 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI8IntegerAttr(30));
- Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
+ arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
+ Value thirty8 = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI8IntegerAttr(30));
+ Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
rewriter
@@ -1056,20 +1059,21 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply output
// zeropoint.
if (outputZpVal != 0) {
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
- scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
+ auto outputZp = arith::ConstantOp::create(
+ rewriter, loc,
+ b.getIntegerAttr(scaled.getType(), outputZpVal));
+ scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
.getResult();
}
// Apply Clip.
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
- auto min = rewriter.create<arith::ConstantIntOp>(
- loc, accETy,
+ auto min = arith::ConstantIntOp::create(
+ rewriter, loc, accETy,
APInt::getSignedMinValue(outBitwidth).getSExtValue());
- auto max = rewriter.create<arith::ConstantIntOp>(
- loc, accETy,
+ auto max = arith::ConstantIntOp::create(
+ rewriter, loc, accETy,
APInt::getSignedMaxValue(outBitwidth).getSExtValue());
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
/*isUnsigned=*/false);
@@ -1078,11 +1082,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Convert type.
if (resultETy != clamp.getType()) {
poolVal =
- rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
+ arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
}
}
- rewriter.create<linalg::YieldOp>(loc, poolVal);
+ linalg::YieldOp::create(rewriter, loc, poolVal);
});
rewriter.replaceOp(op, genericOp.getResult(0));
@@ -1107,8 +1111,9 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
auto permutedSizes =
applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
- auto permutedInit = rewriter.create<tensor::EmptyOp>(
- loc, permutedSizes, op.getInput1().getType().getElementType());
+ auto permutedInit =
+ tensor::EmptyOp::create(rewriter, loc, permutedSizes,
+ op.getInput1().getType().getElementType());
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
op, op.getInput1(), permutedInit,
llvm::to_vector(llvm::map_range(
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index 7dbccd19a0518..b83f5ec9b0283 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -27,8 +27,8 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto variableType = tosa::getVariableType(op);
- auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
- op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
+ auto newVariable = mlir::ml_program::GlobalOp::create(
+ rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
@@ -45,8 +45,8 @@ class VariableWriteOpConverter
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
- auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
- op.getLoc(), globalSymbolRef, op.getInput1());
+ auto newVariableWrite = ml_program::GlobalStoreOp::create(
+ rewriter, op.getLoc(), globalSymbolRef, op.getInput1());
rewriter.replaceOp(op, newVariableWrite);
return success();
}
@@ -60,8 +60,8 @@ class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
- auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
- op.getLoc(), op.getType(), globalSymbolRef);
+ auto newVariableRead = ml_program::GlobalLoadOp::create(
+ rewriter, op.getLoc(), op.getType(), globalSymbolRef);
rewriter.replaceOp(op, newVariableRead);
return success();
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 03f9d20ad69de..aa6b4164e9876 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -30,7 +30,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion,
auto yield = cast<YieldOp>(headBlock->getTerminator());
rewriter.setInsertionPoint(yield);
- rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
+ scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
rewriter.eraseOp(yield);
headBlock->eraseArguments(0, headBlock->getNumArguments());
@@ -46,13 +46,13 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
auto yield = cast<YieldOp>(headBlock->getTerminator());
rewriter.setInsertionPoint(yield);
if (isCond) {
- auto condition =
- rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
- rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
- headBlock->getArguments());
+ auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
+ yield.getOperand(0));
+ scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
+ headBlock->getArguments());
} else {
rewriter.setInsertionPoint(yield);
- rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
+ scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
}
rewriter.eraseOp(yield);
}
@@ -66,9 +66,9 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
- rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
- auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
- condition, true);
+ tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
+ auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
+ condition, true);
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
rewriter);
@@ -88,7 +88,7 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
static Value createIndexConst(OpBuilder &builder, Location loc,
int64_t value) {
- return builder.create<arith::ConstantIndexOp>(loc, value);
+ return arith::ConstantIndexOp::create(builder, loc, value);
}
public:
@@ -119,9 +119,9 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto n = ivs[0];
// Read the index and cast it to index type
- auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
- auto castIndex = builder.create<arith::IndexCastOp>(
- loc, builder.getIndexType(), index);
+ auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
+ auto castIndex = arith::IndexCastOp::create(
+ builder, loc, builder.getIndexType(), index);
// Offset, sizes, and strides for the input tensor
auto inputOffset = llvm::to_vector(ivs);
@@ -130,13 +130,13 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
llvm::SmallVector<Value> sizes = {one, one, dimC};
llvm::SmallVector<Value> strides = {one, one, one};
- auto slice = builder.create<tensor::ExtractSliceOp>(
- loc, input, inputOffset, sizes, strides);
+ auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
+ inputOffset, sizes, strides);
// Insert the slice into the output accumulator tensor.
llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
- auto updated = builder.create<tensor::InsertSliceOp>(
- loc, slice, args[0], outputOffset, sizes, strides);
+ auto updated = tensor::InsertSliceOp::create(
+ builder, loc, slice, args[0], outputOffset, sizes, strides);
return {updated};
};
@@ -155,8 +155,8 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
- auto newWhile = rewriter.create<scf::WhileOp>(
- op.getLoc(), op.getResultTypes(), op.getInputList());
+ auto newWhile = scf::WhileOp::create(
+ rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c6cbcb0f8ab2b..2945ae3b49f1f 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -308,15 +308,15 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
if (ShapedType::isStatic(sizes.back()))
continue;
- auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
- auto offset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(sliceStarts[index]));
- dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+ auto dim = tensor::DimOp::create(rewriter, loc, input, index);
+ auto offset = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
+ dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
}
- auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
- ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
+ auto newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
+ dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getDenseI64ArrayAttr(strides));
@@ -361,7 +361,7 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
loc, padOp.getPadConst(),
- ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)}));
+ ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));
if (!padConstant) {
return rewriter.notifyMatchFailure(
@@ -375,16 +375,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
highValues.reserve(rank);
for (int i = 0; i < rank; i++) {
- Value lowVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(paddingVals[2 * i]));
- Value highVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
+ Value lowVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+ Value highVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
lowValues.push_back(lowVal);
highValues.push_back(highVal);
}
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, padOp.getType(), input, lowValues, highValues, padConstant);
+ auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
+ lowValues, highValues, padConstant);
rewriter.replaceOp(padOp, newPadOp.getResult());
return success();
@@ -402,7 +402,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
Location loc = op.getLoc();
int axis = op.getAxis();
Value axisValue =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
int64_t rank = resultType.getRank();
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
@@ -439,8 +439,9 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
}
- Value result = rewriter.create<tensor::EmptyOp>(
- loc, resultType.getShape(), resultType.getElementType(), dynDims);
+ Value result =
+ tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
+ resultType.getElementType(), dynDims);
for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d6f9495b2567c..125ea1eb60ed6 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -226,22 +226,22 @@ struct BroadcastOpToArmSMELowering
(srcVectorType && (srcVectorType.getRank() == 0))) {
// Broadcast scalar or 0-d vector to 1-d vector.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- broadcastOp1D = rewriter.create<vector::BroadcastOp>(
- loc, tileSliceType, broadcastOp.getSource());
+ broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
+ broadcastOp.getSource());
} else if (srcVectorType && (srcVectorType.getRank() == 1))
// Value to broadcast is already a 1-d vector, nothing to do.
broadcastOp1D = broadcastOp.getSource();
else
return failure();
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to broadcast the value
// to each tile slice.
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
@@ -292,15 +292,15 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
// First, broadcast the scalar to a 1-d vector.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
- loc, tileSliceType, splatOp.getInput());
+ Value broadcastOp1D = vector::BroadcastOp::create(
+ rewriter, loc, tileSliceType, splatOp.getInput());
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
@@ -370,22 +370,22 @@ struct TransposeOpToArmSMELowering
// Allocate buffer to store input tile to.
Value vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- Value minTileSlices = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ Value minTileSlices = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
Value c0 =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0));
Value numTileSlices =
- rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
+ arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
auto bufferType =
MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
tileType.getElementType());
- auto buffer = rewriter.create<memref::AllocaOp>(
- loc, bufferType, ValueRange{numTileSlices, numTileSlices});
+ auto buffer = memref::AllocaOp::create(
+ rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices});
// Store input tile.
- auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
- loc, input, buffer, ValueRange{c0, c0});
+ auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
+ buffer, ValueRange{c0, c0});
// Reload input tile vertically.
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
@@ -488,10 +488,10 @@ struct VectorOuterProductToArmSMELowering
Value rhsMaskDim = createMaskOp.getOperand(1);
VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
- Value lhsMask =
- rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
- Value rhsMask =
- rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
+ Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
+ lhsMaskDim);
+ Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
+ rhsMaskDim);
return std::make_pair(lhsMask, rhsMask);
}
@@ -531,8 +531,8 @@ struct VectorExtractToArmSMELowering
}
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, sourceVector, sliceIndex);
+ auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, sourceVector, sliceIndex);
if (position.size() == 1) {
// Single index case: Extracts a 1D slice.
@@ -593,10 +593,10 @@ struct VectorInsertToArmSMELowering
if (position.size() == 2) {
// Two indices case: Insert single element into tile.
// We need to first extract the existing slice and update the element.
- tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, insertOp.getDest(), sliceIndex);
- tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
- position[1]);
+ tileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, insertOp.getDest(), sliceIndex);
+ tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
+ position[1]);
}
// Insert the slice into the destination tile.
@@ -642,23 +642,24 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
auto loc = printOp.getLoc();
// Create a loop over the rows of the tile.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto minTileRows =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
{
// Loop body.
rewriter.setInsertionPointToStart(forOp.getBody());
// Extract the current row from the tile.
Value rowIndex = forOp.getInductionVar();
- auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
- loc, printOp.getSource(), rowIndex);
+ auto tileSlice = arm_sme::ExtractTileSliceOp::create(
+ rewriter, loc, printOp.getSource(), rowIndex);
// Print the row with a 1D vector.print.
- rewriter.create<vector::PrintOp>(loc, tileSlice,
- printOp.getPunctuation());
+ vector::PrintOp::create(rewriter, loc, tileSlice,
+ printOp.getPunctuation());
}
rewriter.eraseOp(printOp);
@@ -707,8 +708,8 @@ struct FoldTransferWriteOfExtractTileSlice
Value mask = writeOp.getMask();
if (!mask) {
auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
- mask = rewriter.create<arith::ConstantOp>(
- writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+ mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
+ DenseElementsAttr::get(maskType, true));
}
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
@@ -776,10 +777,10 @@ struct ExtractFromCreateMaskToPselLowering
// Create the two 1-D masks at the location of the 2-D create_mask (which is
// usually outside a loop). This prevents the need for later hoisting.
rewriter.setInsertionPoint(createMaskOp);
- auto rowMask = rewriter.create<vector::CreateMaskOp>(
- loc, rowMaskType, createMaskOp.getOperand(0));
- auto colMask = rewriter.create<vector::CreateMaskOp>(
- loc, colMaskType, createMaskOp.getOperand(1));
+ auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
+ createMaskOp.getOperand(0));
+ auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
+ createMaskOp.getOperand(1));
rewriter.setInsertionPoint(extractOp);
auto position =
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 9a8eb72d72925..77aab85483a8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -412,22 +412,22 @@ struct PrepareContractToGPUMMA
if (maps == infer({{m, k}, {k, n}, {m, n}}))
return rewriter.notifyMatchFailure(op, "contraction already prepared");
if (maps == infer({{m, k}, {n, k}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
std::swap(lhs, rhs);
} else {
@@ -494,13 +494,13 @@ struct CombineTransferReadOpTranspose final
// Fuse through the integer extend op.
if (extOp) {
if (isa<arith::ExtSIOp>(extOp))
- result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+ result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
.getResult();
else if (isa<arith::ExtUIOp>(extOp))
- result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
+ result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
.getResult();
else
- result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
+ result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
.getResult();
}
@@ -579,8 +579,8 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
}
gpu::MMAMatrixType type =
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
- Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
- op.getLoc(), type, op.getBase(), op.getIndices(),
+ Value load = gpu::SubgroupMmaLoadMatrixOp::create(
+ rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
rewriter.getIndexAttr(*stride),
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
@@ -610,8 +610,8 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
}
Value matrix = it->second;
- auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
- op.getLoc(), matrix, op.getBase(), op.getIndices(),
+ auto store = gpu::SubgroupMmaStoreMatrixOp::create(
+ rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
@@ -661,8 +661,8 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
return rewriter.notifyMatchFailure(op, "not a splat");
}
- Value result = rewriter.create<arith::ConstantOp>(
- op.getLoc(), vectorType,
+ Value result = arith::ConstantOp::create(
+ rewriter, op.getLoc(), vectorType,
DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
valueMapping[op.getResult()] = result;
return success();
@@ -743,7 +743,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
}
// Adjust the load offset.
- auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
@@ -757,8 +757,9 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
indices);
- nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
- loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
+ nvgpu::LdMatrixOp newOp =
+ nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
+ indices, *transpose, params->numTiles);
valueMapping[op] = newOp->getResult(0);
return success();
}
@@ -782,17 +783,17 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
"conversion to distributed non-ldmatrix compatible load");
}
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
// This is the individual element type.
Type loadedElType = regInfo->registerLLVMType;
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value fill = rewriter.create<arith::ConstantOp>(
- op.getLoc(), vectorType.getElementType(),
+ Value fill = arith::ConstantOp::create(
+ rewriter, op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
- rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
+ vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
@@ -809,16 +810,16 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
if (failed(coords))
return rewriter.notifyMatchFailure(op, "no coords");
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
- op.getBase(), newIndices);
- result = rewriter.create<vector::InsertOp>(loc, el, result, i);
+ Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
+ op.getBase(), newIndices);
+ result = vector::InsertOp::create(rewriter, loc, el, result, i);
}
} else {
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -828,8 +829,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
innerIdx++) {
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
rewriter, op.getLoc(), *warpMatrixInfo);
@@ -839,10 +840,10 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
- op.getBase(), newIndices);
- result = rewriter.create<vector::InsertOp>(
- op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
+ Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
+ op.getBase(), newIndices);
+ result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
+ ArrayRef<int64_t>{i, innerIdx});
}
}
}
@@ -916,11 +917,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
- Value logicalValueId = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexType(),
+ Value logicalValueId = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexType(),
rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
rewriter, op.getLoc(), *warpMatrixInfo);
@@ -928,11 +929,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return rewriter.notifyMatchFailure(op, "no coords");
Value el =
- rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+ vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferWriteOp>(
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
+ vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
@@ -1015,8 +1016,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
else if (offsets[1])
sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
- Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, sourceVector, sliceOffset, sliceShape, strides);
+ Value newOp = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
valueMapping[op] = newOp;
return success();
@@ -1035,9 +1036,10 @@ convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
itC == valueMapping.end())
return rewriter.notifyMatchFailure(op, "no mapping");
Value opA = itA->second, opB = itB->second, opC = itC->second;
- Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
- op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
- /*b_transpose=*/UnitAttr());
+ Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
+ opC.getType(), opA, opB, opC,
+ /*a_transpose=*/UnitAttr(),
+ /*b_transpose=*/UnitAttr());
valueMapping[op.getResult()] = matmul;
return success();
}
@@ -1058,8 +1060,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
- Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
- op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
+ Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
+ rewriter.getI64ArrayAttr({m, n, k}));
valueMapping[op.getResult()] = matmul;
return success();
}
@@ -1076,13 +1078,13 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
auto splat =
cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
auto scalarConstant =
- rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
+ arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
auto vecType = cast<VectorType>(op.getType());
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
- op.getLoc(), type, scalarConstant);
+ auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
+ type, scalarConstant);
valueMapping[op.getResult()] = matrix;
return success();
}
@@ -1100,8 +1102,8 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
auto vecType = op.getResultVectorType();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
- op.getLoc(), type, op.getSource());
+ auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
+ type, op.getSource());
valueMapping[op.getResult()] = matrix;
return success();
}
@@ -1118,9 +1120,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getInitArgs());
llvm::append_range(operands, newInitArgs);
- scf::ForOp newLoop = rewriter.create<scf::ForOp>(
- loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
- operands);
+ scf::ForOp newLoop =
+ scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
+ loop.getUpperBound(), loop.getStep(), operands);
rewriter.eraseBlock(newLoop.getBody());
newLoop.getRegion().getBlocks().splice(
@@ -1189,7 +1191,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
yieldOperands.push_back(it->second);
}
- rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+ scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
rewriter.eraseOp(op);
@@ -1220,8 +1222,8 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
resultType.getOperand());
}
- Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
- op->getLoc(), resultType, matrixOperands, opType);
+ Value newOp = gpu::SubgroupMmaElementwiseOp::create(
+ rewriter, op->getLoc(), resultType, matrixOperands, opType);
valueMapping[op->getResult(0)] = newOp;
return success();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e4ff770a807c6..9cd491caa9421 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -43,13 +43,13 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
auto idxType = rewriter.getIndexType();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
+ auto constant = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
- return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
- constant);
+ return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
+ constant);
}
- return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
+ return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
}
// Helper that picks the proper sequence for extracting.
@@ -58,13 +58,13 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank <= 1) {
auto idxType = rewriter.getIndexType();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
+ auto constant = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
- return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
- constant);
+ return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
+ constant);
}
- return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
+ return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
}
// Helper that returns data layout alignment of a vector.
@@ -141,9 +141,9 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
auto ptrsType =
LLVM::getVectorType(pType, vectorType.getDimSize(0),
/*isScalable=*/vectorType.getScalableDims()[0]);
- return rewriter.create<LLVM::GEPOp>(
- loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
- base, index);
+ return LLVM::GEPOp::create(
+ rewriter, loc, ptrsType,
+ typeConverter.convertType(memRefType.getElementType()), base, index);
}
/// Convert `foldResult` into a Value. Integer attribute is converted to
@@ -152,7 +152,7 @@ static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
if (auto attr = dyn_cast<Attribute>(foldResult)) {
auto intAttr = cast<IntegerAttr>(attr);
- return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
+ return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
}
return cast<Value>(foldResult);
@@ -440,32 +440,32 @@ class ReductionNeutralFPMax {};
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
- rewriter.getZeroAttr(llvmType));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getZeroAttr(llvmType));
}
/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getIntegerAttr(llvmType, 1));
}
/// Create the reduction neutral fp one value.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
+ return LLVM::ConstantOp::create(rewriter, loc, llvmType,
+ rewriter.getFloatAttr(llvmType, 1.0));
}
/// Create the reduction neutral all-ones value.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(
llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}
@@ -474,8 +474,8 @@ static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -484,8 +484,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -494,8 +494,8 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -504,8 +504,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
llvmType.getIntOrFloatBitWidth())));
}
@@ -515,8 +515,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = cast<FloatType>(llvmType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/false)));
@@ -527,8 +527,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
auto floatType = cast<FloatType>(llvmType);
- return rewriter.create<LLVM::ConstantOp>(
- loc, llvmType,
+ return LLVM::ConstantOp::create(
+ rewriter, loc, llvmType,
rewriter.getFloatAttr(
llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
/*Negative=*/true)));
@@ -556,19 +556,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
- Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value baseVecLength = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
if (!vType.getScalableDims()[0])
return baseVecLength;
// For a scalable vector type, create and return `vScale * baseVecLength`.
- Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+ Value vScale = vector::VectorScaleOp::create(rewriter, loc);
vScale =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
Value scalableVecLength =
- rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
+ arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
return scalableVecLength;
}
@@ -581,10 +581,11 @@ static Value createIntegerReductionArithmeticOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator) {
- Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+ Value result =
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
if (accumulator)
- result = rewriter.create<ScalarOp>(loc, accumulator, result);
+ result = ScalarOp::create(rewriter, loc, accumulator, result);
return result;
}
@@ -596,11 +597,12 @@ template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
- Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+ Value result =
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
if (accumulator) {
Value cmp =
- rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
- result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
+ LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
+ result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
}
return result;
}
@@ -631,12 +633,11 @@ static Value createFPReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
Value result =
- rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
+ LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
if (accumulator) {
- result =
- rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
- loc, result, accumulator);
+ result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
+ rewriter, loc, result, accumulator);
}
return result;
@@ -667,7 +668,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
- return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
+ return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
}
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
@@ -682,8 +683,8 @@ lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
Value mask, LLVM::FastmathFlagsAttr fmf) {
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
rewriter, loc, llvmType, vectorOperand.getType());
- const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
- loc, mask, vectorOperand, vectorMaskNeutral);
+ const Value selectedVectorByMask = LLVM::SelectOp::create(
+ rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
@@ -695,9 +696,9 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Value accumulator, LLVM::FastmathFlagsAttr fmf) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
- return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand, fmf);
+ return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand,
+ fmf);
}
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
@@ -710,9 +711,8 @@ lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
Value vectorOperand, Value accumulator) {
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
- return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand);
+ return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand);
}
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
@@ -723,9 +723,9 @@ static Value lowerPredicatedReductionWithStartValue(
llvmType, accumulator);
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
- return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
- /*startValue=*/accumulator,
- vectorOperand, mask, vectorLength);
+ return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
+ /*startValue=*/accumulator, vectorOperand,
+ mask, vectorLength);
}
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
@@ -1036,8 +1036,8 @@ class VectorShuffleOpConversion
// For rank 0 and 1, where both operands have *exactly* the same vector
// type, there is direct shuffle support in LLVM. Use it!
if (rank <= 1 && v1Type == v2Type) {
- Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getV1(), adaptor.getV2(),
+ Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getV1(), adaptor.getV2(),
llvm::to_vector_of<int32_t>(mask));
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
@@ -1050,7 +1050,7 @@ class VectorShuffleOpConversion
eltType = arrayType.getElementType();
else
eltType = cast<VectorType>(llvmType).getElementType();
- Value insert = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
+ Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
int64_t insPos = 0;
for (int64_t extPos : mask) {
Value value = adaptor.getV1();
@@ -1087,9 +1087,9 @@ class VectorExtractElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = extractEltOp.getLoc();
auto idxType = rewriter.getIndexType();
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ auto zero = LLVM::ConstantOp::create(rewriter, loc,
+ typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
@@ -1158,13 +1158,14 @@ class VectorExtractOpConversion
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
}
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, getAsIntegers(position));
+ extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
+ getAsIntegers(position));
}
if (extractsScalar) {
- extracted = rewriter.create<LLVM::ExtractElementOp>(
- loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
+ extracted = LLVM::ExtractElementOp::create(
+ rewriter, loc, extracted,
+ getAsLLVMValue(rewriter, loc, positionVec.back()));
}
rewriter.replaceOp(extractOp, extracted);
@@ -1221,9 +1222,9 @@ class VectorInsertElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = insertEltOp.getLoc();
auto idxType = rewriter.getIndexType();
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ auto zero = LLVM::ConstantOp::create(rewriter, loc,
+ typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
@@ -1307,8 +1308,8 @@ class VectorInsertOpConversion
// llvm.extractvalue does not support dynamic dimensions.
return failure();
}
- sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getDest(),
+ sourceAggregate = LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getDest(),
getAsIntegers(positionOf1DVectorWithinAggregate));
} else {
// No-aggregate case. The destination for the InsertElementOp is just
@@ -1316,16 +1317,16 @@ class VectorInsertOpConversion
sourceAggregate = adaptor.getDest();
}
// Insert the scalar into the 1D vector.
- sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
- loc, sourceAggregate.getType(), sourceAggregate,
+ sourceAggregate = LLVM::InsertElementOp::create(
+ rewriter, loc, sourceAggregate.getType(), sourceAggregate,
adaptor.getValueToStore(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
}
Value result = sourceAggregate;
if (isNestedAggregate) {
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), sourceAggregate,
+ result = LLVM::InsertValueOp::create(
+ rewriter, loc, adaptor.getDest(), sourceAggregate,
getAsIntegers(positionOf1DVectorWithinAggregate));
}
@@ -1404,15 +1405,15 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto loc = op.getLoc();
auto elemType = vType.getElementType();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
- Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
- Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
- Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
- Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
- desc = rewriter.create<InsertOp>(loc, fma, desc, i);
+ Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
+ Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
+ Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
+ Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
+ desc = InsertOp::create(rewriter, loc, fma, desc, i);
}
rewriter.replaceOp(op, desc);
return success();
@@ -1502,7 +1503,7 @@ class VectorTypeCastOpConversion
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
- auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+ auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
desc.setOffset(rewriter, loc, zero);
// Fill size and stride descriptors in memref.
@@ -1511,11 +1512,12 @@ class VectorTypeCastOpConversion
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
- auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+ auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
desc.setSize(rewriter, loc, index, size);
auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
(*targetStrides)[index]);
- auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+ auto stride =
+ LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
desc.setStride(rewriter, loc, index, stride);
}
@@ -1543,14 +1545,15 @@ class VectorCreateMaskOpConversion
IntegerType idxType =
force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
auto loc = op->getLoc();
- Value indices = rewriter.create<LLVM::StepVectorOp>(
- loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
- /*isScalable=*/true));
+ Value indices = LLVM::StepVectorOp::create(
+ rewriter, loc,
+ LLVM::getVectorType(idxType, dstType.getShape()[0],
+ /*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
- Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
- Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- indices, bounds);
+ Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
+ Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ indices, bounds);
rewriter.replaceOp(op, comp);
return success();
}
@@ -1706,16 +1709,16 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
- value = rewriter.create<arith::ExtUIOp>(
- loc, IntegerType::get(rewriter.getContext(), 64), value);
+ value = arith::ExtUIOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
- value = rewriter.create<arith::ExtSIOp>(
- loc, IntegerType::get(rewriter.getContext(), 64), value);
+ value = arith::ExtSIOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
- value = rewriter.create<LLVM::BitcastOp>(
- loc, IntegerType::get(rewriter.getContext(), 16), value);
+ value = LLVM::BitcastOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
@@ -1727,8 +1730,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Helper to emit a call.
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
Operation *ref, ValueRange params = ValueRange()) {
- rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
- params);
+ LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
+ params);
}
};
@@ -1754,9 +1757,9 @@ struct VectorBroadcastScalarToLowRankLowering
// First insert it into a poison vector so we can shuffle it.
auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
- rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
- auto zero = rewriter.create<LLVM::ConstantOp>(
- broadcast.getLoc(),
+ LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
+ auto zero = LLVM::ConstantOp::create(
+ rewriter, broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
@@ -1768,8 +1771,9 @@ struct VectorBroadcastScalarToLowRankLowering
}
// For 1-d vector, we additionally do a `vectorshuffle`.
- auto v = rewriter.create<LLVM::InsertElementOp>(
- broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
+ auto v =
+ LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
+ poison, adaptor.getSource(), zero);
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
@@ -1811,26 +1815,26 @@ struct VectorBroadcastScalarToNdLowering
return failure();
// Construct returned value.
- Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
+ Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
// Construct a 1-D vector with the broadcasted value that we insert in all
// the places within the returned descriptor.
- Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
- auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+ Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
+ auto zero = LLVM::ConstantOp::create(
+ rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
- Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.getSource(), zero);
+ Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
+ vdesc, adaptor.getSource(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
- v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
+ v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
// Iterate of linear index, convert to coords space and insert broadcasted
// 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
+ desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
});
rewriter.replaceOp(broadcast, desc);
return success();
@@ -1900,13 +1904,13 @@ struct VectorDeinterleaveOpLowering
auto deinterleaveResults = deinterleaveOp.getResultTypes();
auto packedOpResults =
llvmTypeConverter->packOperationResults(deinterleaveResults);
- auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
- loc, packedOpResults, adaptor.getSource());
+ auto intrinsic = LLVM::vector_deinterleave2::create(
+ rewriter, loc, packedOpResults, adaptor.getSource());
- auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
- loc, intrinsic->getResult(0), 0);
- auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
- loc, intrinsic->getResult(0), 1);
+ auto evenResult = LLVM::ExtractValueOp::create(
+ rewriter, loc, intrinsic->getResult(0), 0);
+ auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
+ intrinsic->getResult(0), 1);
rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
return success();
@@ -1929,11 +1933,11 @@ struct VectorDeinterleaveOpLowering
oddShuffleMask.push_back(i);
}
- auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
- auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getSource(), poison, evenShuffleMask);
- auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.getSource(), poison, oddShuffleMask);
+ auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
+ auto evenShuffle = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
+ auto oddShuffle = LLVM::ShuffleVectorOp::create(
+ rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
return success();
@@ -1956,9 +1960,9 @@ struct VectorFromElementsLowering
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
+ result = vector::InsertOp::create(rewriter, loc, val, result, idx);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
@@ -1982,12 +1986,12 @@ struct VectorToElementsLowering
if (element.use_empty())
continue;
- auto constIdx = rewriter.create<LLVM::ConstantOp>(
- loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+ auto constIdx = LLVM::ConstantOp::create(
+ rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
auto llvmType = typeConverter->convertType(element.getType());
- Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
- source, constIdx);
+ Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
+ source, constIdx);
results[idx] = result;
}
@@ -2098,7 +2102,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Value lhs = op.getLhs();
auto lhsMap = op.getIndexingMapsArray()[0];
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
- lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+ lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
return failure();
@@ -2106,7 +2110,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Value rhs = op.getRhs();
auto rhsMap = op.getIndexingMapsArray()[1];
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
- rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+ rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
return failure();
@@ -2119,20 +2123,20 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
Type flattenedLHSType =
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
- lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+ lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
Type flattenedRHSType =
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
- rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+ rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
- Value mul = rew.create<LLVM::MatrixMultiplyOp>(
- loc,
+ Value mul = LLVM::MatrixMultiplyOp::create(
+ rew, loc,
VectorType::get(lhsRows * rhsColumns,
cast<VectorType>(lhs.getType()).getElementType()),
lhs, rhs, lhsRows, lhsColumns, rhsColumns);
- mul = rew.create<vector::ShapeCastOp>(
- loc,
+ mul = vector::ShapeCastOp::create(
+ rew, loc,
VectorType::get({lhsRows, rhsColumns},
getElementTypeOrSelf(op.getAcc().getType())),
mul);
@@ -2140,15 +2144,15 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
// ACC must be C(m, n) or C(n, m).
auto accMap = op.getIndexingMapsArray()[2];
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
- mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+ mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
llvm_unreachable("invalid contraction semantics");
- Value res =
- isa<IntegerType>(elementType)
- ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
- : static_cast<Value>(
- rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+ Value res = isa<IntegerType>(elementType)
+ ? static_cast<Value>(
+ arith::AddIOp::create(rew, loc, op.getAcc(), mul))
+ : static_cast<Value>(
+ arith::AddFOp::create(rew, loc, op.getAcc(), mul));
return res;
}
@@ -2181,11 +2185,11 @@ class TransposeOpToMatrixTransposeOpLowering
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
- Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
- loc, flattenedType, matrix, rows, columns);
+ Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
+ matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 43732f58a4e0a..4c1047a8871a5 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -132,9 +132,9 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
Value value) {
if (hasRetVal) {
assert(value && "Expected non-empty value");
- b.create<scf::YieldOp>(loc, value);
+ scf::YieldOp::create(b, loc, value);
} else {
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
}
@@ -154,7 +154,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
+ return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -201,22 +201,22 @@ static Value generateInBoundsCheck(
Value base = xferOp.getIndices()[*dim];
Value memrefIdx =
affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
- cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
- memrefIdx);
+ cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
+ memrefIdx);
}
// Condition check 2: Masked in?
if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
if (cond)
- cond = lb.create<arith::AndIOp>(cond, maskCond);
+ cond = arith::AndIOp::create(lb, cond, maskCond);
else
cond = maskCond;
}
// If the condition is non-empty, generate an SCF::IfOp.
if (cond) {
- auto check = lb.create<scf::IfOp>(
- cond,
+ auto check = scf::IfOp::create(
+ lb, cond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
@@ -226,7 +226,7 @@ static Value generateInBoundsCheck(
if (outOfBoundsCase) {
maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
} else {
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
});
@@ -303,14 +303,15 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
BufferAllocs result;
auto bufferType = MemRefType::get({}, xferOp.getVectorType());
- result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
+ result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);
if (xferOp.getMask()) {
auto maskType = MemRefType::get({}, xferOp.getMask().getType());
- auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
+ auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
b.setInsertionPoint(xferOp);
- b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
- result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
+ memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
+ result.maskBuffer =
+ memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
}
return result;
@@ -421,14 +422,15 @@ struct Strategy<TransferReadOp> {
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, vecType, xferOp.getBase(), xferIndices,
+ auto newXferOp = vector::TransferReadOp::create(
+ b, loc, vecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeApplyPassLabel(b, newXferOp, options.targetRank);
- b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
+ memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
+ storeIndices);
return newXferOp;
}
@@ -444,8 +446,9 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
- auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
- b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
+ auto vec =
+ vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
+ memref::StoreOp::create(b, loc, vec, buffer, storeIndices);
return Value();
}
@@ -506,12 +509,12 @@ struct Strategy<TransferWriteOp> {
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
- auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
+ auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
- auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, type, vec, source, xferIndices,
+ auto newXferOp = vector::TransferWriteOp::create(
+ b, loc, type, vec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
@@ -610,8 +613,8 @@ struct PrepareTransferReadConversion
}
Location loc = xferOp.getLoc();
- rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
- buffers.dataBuffer);
+ memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
+ buffers.dataBuffer);
rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
return success();
@@ -653,9 +656,9 @@ struct PrepareTransferWriteConversion
Location loc = xferOp.getLoc();
auto buffers = allocBuffers(rewriter, xferOp);
- rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
- buffers.dataBuffer);
- auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
+ memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
+ buffers.dataBuffer);
+ auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getValueToStoreMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
@@ -735,17 +738,17 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
auto signlessTargetVectorType =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
- value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
- value);
+ value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
+ value);
if (value.getType() != signlessTargetVectorType) {
if (width == 1 || intTy.isUnsigned())
- value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
- value);
+ value = arith::ExtUIOp::create(rewriter, loc,
+ signlessTargetVectorType, value);
else
- value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
- value);
+ value = arith::ExtSIOp::create(rewriter, loc,
+ signlessTargetVectorType, value);
}
- value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
+ value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
vectorType = targetVectorType;
}
@@ -762,29 +765,30 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
std::multiplies<int64_t>());
auto flatVectorType =
VectorType::get({flatLength}, vectorType.getElementType());
- value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
+ value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
}
vector::PrintOp firstClose;
SmallVector<Value, 8> loopIndices;
for (unsigned d = 0; d < shape.size(); d++) {
// Setup loop bounds and step.
- Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
- Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value upperBound =
+ arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
+ Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
if (!scalableDimensions.empty() && scalableDimensions[d]) {
- auto vscale = rewriter.create<vector::VectorScaleOp>(
- loc, rewriter.getIndexType());
- upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc,
+ rewriter.getIndexType());
+ upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
}
- auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
+ auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);
// Create a loop to print the elements surrounded by parentheses.
- rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
auto loop =
- rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- auto printClose = rewriter.create<vector::PrintOp>(
- loc, vector::PrintPunctuation::Close);
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
+ auto printClose = vector::PrintOp::create(
+ rewriter, loc, vector::PrintPunctuation::Close);
if (!firstClose)
firstClose = printClose;
@@ -793,14 +797,14 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
// Print a comma after all but the last element.
rewriter.setInsertionPointToStart(loop.getBody());
- auto notLastIndex = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
- rewriter.create<scf::IfOp>(loc, notLastIndex,
- [&](OpBuilder &builder, Location loc) {
- builder.create<vector::PrintOp>(
- loc, vector::PrintPunctuation::Comma);
- builder.create<scf::YieldOp>(loc);
- });
+ auto notLastIndex = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
+ scf::IfOp::create(rewriter, loc, notLastIndex,
+ [&](OpBuilder &builder, Location loc) {
+ vector::PrintOp::create(
+ builder, loc, vector::PrintPunctuation::Comma);
+ scf::YieldOp::create(builder, loc);
+ });
rewriter.setInsertionPointToStart(loop.getBody());
}
@@ -810,22 +814,23 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
Value flatIndex;
auto currentStride = 1;
for (int d = shape.size() - 1; d >= 0; d--) {
- auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
- auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
+ auto stride =
+ arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+ auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
if (flatIndex)
- flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
+ flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
else
flatIndex = index;
currentStride *= shape[d];
}
// Print the scalar elements in the inner most loop.
- auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
- rewriter.create<vector::PrintOp>(loc, element,
- vector::PrintPunctuation::NoPunctuation);
+ auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
+ vector::PrintOp::create(rewriter, loc, element,
+ vector::PrintPunctuation::NoPunctuation);
rewriter.setInsertionPointAfter(firstClose);
- rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
+ vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
rewriter.eraseOp(printOp);
return success();
}
@@ -916,7 +921,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
"Failed to unpack one vector dim.");
auto castedDataBuffer =
- locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
+ vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
@@ -935,22 +940,22 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
- locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
+ vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
}
}
// Loop bounds and step.
- auto lb = locB.create<arith::ConstantIndexOp>(0);
- auto ub = locB.create<arith::ConstantIndexOp>(
- castedDataType->getDimSize(castedDataType->getRank() - 1));
- auto step = locB.create<arith::ConstantIndexOp>(1);
+ auto lb = arith::ConstantIndexOp::create(locB, 0);
+ auto ub = arith::ConstantIndexOp::create(
+ locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
+ auto step = arith::ConstantIndexOp::create(locB, 1);
// TransferWriteOps that operate on tensors return the modified tensor and
// require a loop state.
auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
// Generate for loop.
- auto result = locB.create<scf::ForOp>(
- lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+ auto result = scf::ForOp::create(
+ locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
Type stateType = loopState.empty() ? Type() : loopState[0].getType();
@@ -975,8 +980,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
SmallVector<Value, 8> loadIndices;
getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
loadIndices, iv);
- auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
- loadIndices);
+ auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
+ loadIndices);
rewriter.modifyOpInPlace(newXfer, [&]() {
newXfer.getMaskMutable().assign(mask);
});
@@ -1119,30 +1124,30 @@ struct ScalableTransposeTransferWriteConversion
auto transposeSource = transposeOp.getVector();
SmallVector<Value> transposeSourceSlices =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
- return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+ return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
});
// Loop bounds and step.
- auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto ub =
maskDims->empty()
? Value(createVscaleMultiple(vectorType.getDimSize(0)))
: vector::getAsValues(rewriter, loc, maskDims->front()).front();
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Generate a new mask for the slice.
VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
Value sliceMask = nullptr;
if (!maskDims->empty()) {
- sliceMask = rewriter.create<vector::CreateMaskOp>(
- loc, sliceType.clone(rewriter.getI1Type()),
+ sliceMask = vector::CreateMaskOp::create(
+ rewriter, loc, sliceType.clone(rewriter.getI1Type()),
ArrayRef<OpFoldResult>(*maskDims).drop_front());
}
Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
- auto result = rewriter.create<scf::ForOp>(
- loc, lb, ub, step, initLoopArgs,
+ auto result = scf::ForOp::create(
+ rewriter, loc, lb, ub, step, initLoopArgs,
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
// Indices for the new transfer op.
SmallVector<Value, 8> xferIndices;
@@ -1151,25 +1156,25 @@ struct ScalableTransposeTransferWriteConversion
// Extract a transposed slice from the source vector.
SmallVector<Value> transposeElements =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
- return b.create<vector::ExtractOp>(
- loc, transposeSourceSlices[idx], iv);
+ return vector::ExtractOp::create(
+ b, loc, transposeSourceSlices[idx], iv);
});
- auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
- transposeElements);
+ auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
+ transposeElements);
// Create the transfer_write for the slice.
Value dest =
loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
- auto newWriteOp = b.create<vector::TransferWriteOp>(
- loc, sliceVec, dest, xferIndices,
+ auto newWriteOp = vector::TransferWriteOp::create(
+ b, loc, sliceVec, dest, xferIndices,
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
if (sliceMask)
newWriteOp.getMaskMutable().assign(sliceMask);
// Yield from the loop.
- b.create<scf::YieldOp>(loc, loopIterArgs.empty()
- ? ValueRange{}
- : newWriteOp.getResult());
+ scf::YieldOp::create(b, loc,
+ loopIterArgs.empty() ? ValueRange{}
+ : newWriteOp.getResult());
});
if (isTensorOp(writeOp))
@@ -1207,7 +1212,7 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
llvm::SmallVector<int64_t, 1> indices({i});
Location loc = xferOp.getLoc();
- auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
+ auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
newXferOp.getMaskMutable().assign(newMask);
}
@@ -1261,8 +1266,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
- return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1317,7 +1322,7 @@ struct UnrollTransferReadConversion
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
- Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
// FIXME: Rename this lambda - it does much more than just
// in-bounds-check generation.
@@ -1336,8 +1341,8 @@ struct UnrollTransferReadConversion
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.getBase(), xferIndices,
+ auto newXferOp = vector::TransferReadOp::create(
+ b, loc, newXferVecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
@@ -1346,11 +1351,11 @@ struct UnrollTransferReadConversion
if (newXferVecType.getRank() == 0) {
// vector.insert does not accept rank-0 as the non-indexed
// argument. Extract the scalar before inserting.
- valToInser = b.create<vector::ExtractOp>(loc, valToInser,
- SmallVector<int64_t>());
+ valToInser = vector::ExtractOp::create(b, loc, valToInser,
+ SmallVector<int64_t>());
}
- return b.create<vector::InsertOp>(loc, valToInser, vec,
- insertionIndices);
+ return vector::InsertOp::create(b, loc, valToInser, vec,
+ insertionIndices);
},
/*outOfBoundsCase=*/
[&](OpBuilder &b, Location loc) {
@@ -1460,7 +1465,7 @@ struct UnrollTransferWriteConversion
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
- Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
auto updatedSource = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp),
@@ -1477,20 +1482,20 @@ struct UnrollTransferWriteConversion
extractionIndices.push_back(b.getI64IntegerAttr(i));
auto extracted =
- b.create<vector::ExtractOp>(loc, vec, extractionIndices);
+ vector::ExtractOp::create(b, loc, vec, extractionIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
Value xferVec;
if (inputVectorTy.getRank() == 1) {
// When target-rank=0, unrolling would cause the vector input
// argument into `transfer_write` to become a scalar. We solve
// this by broadcasting the scalar to a 0D vector.
- xferVec = b.create<vector::BroadcastOp>(
- loc, VectorType::get({}, extracted.getType()), extracted);
+ xferVec = vector::BroadcastOp::create(
+ b, loc, VectorType::get({}, extracted.getType()), extracted);
} else {
xferVec = extracted;
}
- auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, sourceType, xferVec, source, xferIndices,
+ auto newXferOp = vector::TransferWriteOp::create(
+ b, loc, sourceType, xferVec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
@@ -1572,19 +1577,19 @@ struct Strategy1d<TransferReadOp> {
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
- Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
- return b.create<vector::InsertOp>(loc, val, vec, iv);
+ Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
+ return vector::InsertOp::create(b, loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
- b.create<scf::YieldOp>(loc, nextVec);
+ scf::YieldOp::create(b, loc, nextVec);
}
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Initialize vector with padding value.
Location loc = xferOp.getLoc();
- return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
};
@@ -1601,10 +1606,10 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
- b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
+ auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
+ memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
});
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
}
static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
@@ -1665,15 +1670,15 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
// Loop bounds, step, state...
Location loc = xferOp.getLoc();
auto vecType = xferOp.getVectorType();
- auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value ub =
- rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
+ arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
if (vecType.isScalable()) {
Value vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
}
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
// Generate for loop.
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 06b19335f2aed..986eae33503d1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -161,19 +161,19 @@ static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
Location loc, Value dynamicIndex,
int64_t kPoisonIndex, unsigned vectorSize) {
if (llvm::isPowerOf2_32(vectorSize)) {
- Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
- loc, dynamicIndex.getType(),
+ Value inBoundsMask = spirv::ConstantOp::create(
+ rewriter, loc, dynamicIndex.getType(),
rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
- return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
- inBoundsMask);
+ return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
+ inBoundsMask);
}
- Value poisonIndex = rewriter.create<spirv::ConstantOp>(
- loc, dynamicIndex.getType(),
+ Value poisonIndex = spirv::ConstantOp::create(
+ rewriter, loc, dynamicIndex.getType(),
rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
Value cmpResult =
- rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
- return rewriter.create<spirv::SelectOp>(
- loc, cmpResult,
+ spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
+ return spirv::SelectOp::create(
+ rewriter, loc, cmpResult,
spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
dynamicIndex);
}
@@ -441,8 +441,8 @@ static SmallVector<Value> extractAllElements(
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
- values.push_back(rewriter.create<spirv::CompositeExtractOp>(
- loc, srcVectorType.getElementType(), adaptor.getVector(),
+ values.push_back(spirv::CompositeExtractOp::create(
+ rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
rewriter.getI32ArrayAttr({i})));
}
if (Value acc = adaptor.getAcc())
@@ -495,16 +495,16 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
case vector::CombiningKind::kind: \
if (llvm::isa<IntegerType>(resultType)) { \
- result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
+ result = spirv::iop::create(rewriter, loc, resultType, result, next); \
} else { \
assert(llvm::isa<FloatType>(resultType)); \
- result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
+ result = spirv::fop::create(rewriter, loc, resultType, result, next); \
} \
break
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
+ result = fop::create(rewriter, loc, resultType, result, next); \
break
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
@@ -551,7 +551,7 @@ struct VectorReductionFloatMinMax final
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
+ result = fop::create(rewriter, loc, resultType, result, next); \
break
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
@@ -632,8 +632,8 @@ struct VectorShuffleOpConvert final
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
- return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
- idx);
+ return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
+ idx);
assert(idx == 0 && "Invalid scalar element index");
return scalarOrVec;
@@ -731,11 +731,13 @@ struct VectorDeinterleaveOpConvert final
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeExtractOp`.
if (n == 2) {
- auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
+ auto elem0 = spirv::CompositeExtractOp::create(
+ rewriter, loc, newResultType, sourceVector,
+ rewriter.getI32ArrayAttr({0}));
- auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
+ auto elem1 = spirv::CompositeExtractOp::create(
+ rewriter, loc, newResultType, sourceVector,
+ rewriter.getI32ArrayAttr({1}));
rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
return success();
@@ -752,12 +754,12 @@ struct VectorDeinterleaveOpConvert final
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
// Create two SPIR-V shuffles.
- auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, sourceVector, sourceVector,
+ auto shuffleEven = spirv::VectorShuffleOp::create(
+ rewriter, loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesEven));
- auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, sourceVector, sourceVector,
+ auto shuffleOdd = spirv::VectorShuffleOp::create(
+ rewriter, loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesOdd));
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
@@ -801,10 +803,11 @@ struct VectorLoadOpConverter final
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both are going to be the same, a pointer
// to a scalar.
- Value castedAccessChain = (vectorType.getNumElements() == 1)
- ? accessChain
- : rewriter.create<spirv::BitcastOp>(
- loc, vectorPtrType, accessChain);
+ Value castedAccessChain =
+ (vectorType.getNumElements() == 1)
+ ? accessChain
+ : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
+ accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
@@ -843,10 +846,11 @@ struct VectorStoreOpConverter final
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both are going to be the same, a pointer
// to a scalar.
- Value castedAccessChain = (vectorType.getNumElements() == 1)
- ? accessChain
- : rewriter.create<spirv::BitcastOp>(
- loc, vectorPtrType, accessChain);
+ Value castedAccessChain =
+ (vectorType.getNumElements() == 1)
+ ? accessChain
+ : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
+ accessChain);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
@@ -927,10 +931,10 @@ struct VectorReductionToIntDotProd final
auto v4i8Type = VectorType::get({4}, i8Type);
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
- lhsIn = rewriter.create<spirv::CompositeConstructOp>(
- loc, v4i8Type, ValueRange{lhsIn, zero});
- rhsIn = rewriter.create<spirv::CompositeConstructOp>(
- loc, v4i8Type, ValueRange{rhsIn, zero});
+ lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
+ ValueRange{lhsIn, zero});
+ rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
+ ValueRange{rhsIn, zero});
}
// There's no variant of dot prod ops for unsigned LHS and signed RHS, so
@@ -993,14 +997,14 @@ struct VectorReductionToFPDotProd final
Attribute oneAttr =
rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
- rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
+ rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
}
assert(lhs);
assert(rhs);
- Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
+ Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
if (acc)
- res = rewriter.create<spirv::FAddOp>(loc, acc, res);
+ res = spirv::FAddOp::create(rewriter, loc, acc, res);
rewriter.replaceOp(op, res);
return success();
@@ -1035,7 +1039,8 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
source.reserve(numElements);
for (int64_t i = 0; i < numElements; ++i) {
Attribute intAttr = rewriter.getIntegerAttr(intType, i);
- Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ Value constOp =
+ spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
source.push_back(constOp);
}
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
@@ -1078,8 +1083,8 @@ struct VectorToElementOpConvert final
if (element.use_empty())
continue;
- Value result = rewriter.create<spirv::CompositeExtractOp>(
- loc, elementType, adaptor.getSource(),
+ Value result = spirv::CompositeExtractOp::create(
+ rewriter, loc, elementType, adaptor.getSource(),
rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
results[idx] = result;
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2e6a16ddbfdaa..80107554144cf 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -108,15 +108,15 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
- ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
- getAsOpFoldResult(offsets));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ getAsOpFoldResult(offsets));
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
unsigned srcRank = srcTy.getRank();
for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
+ sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
SmallVector<int64_t> constOffsets;
SmallVector<Value> dynOffsets;
@@ -135,18 +135,18 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
// Compute strides in reverse order.
SmallVector<Value> dynStrides;
- Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
- rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
+ arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
dynStrides.push_back(accStride);
}
std::reverse(dynStrides.begin(), dynStrides.end());
- ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
- loc, descType, src, dynOffsets, dynShapes, dynStrides,
+ ndDesc = xegpu::CreateNdDescOp::create(
+ rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
DenseI64ArrayAttr::get(rewriter.getContext(), strides));
@@ -200,10 +200,10 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = rewriter.create<xegpu::LoadNdOp>(
- loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -238,9 +238,9 @@ struct TransferWriteLowering
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
- rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -269,8 +269,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
- loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+ auto loadNdOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
@@ -303,9 +303,9 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeNdOp =
- rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
return success();
@@ -339,8 +339,9 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
- auto dpasOp = rewriter.create<xegpu::DpasOp>(
- loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
+ auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
+ TypeRange{contractOp.getResultType()},
+ ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
return success();
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index a8380b9669f0f..2411af043f3f7 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -251,7 +251,7 @@ static LLVM::CallOp createDeviceFunctionCall(
for (auto [idx, attrName] : paramAttrs)
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
- auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
callOp->setAttrs(funcOp->getAttrs());
return callOp;
@@ -299,7 +299,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
VectorType newTy = VectorType::get(
vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
if (origTy != newTy)
- val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+ val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
return val;
};
@@ -326,7 +326,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
: cOrigTy;
VectorType resTy = cTy;
if (cOrigTy != cTy)
- c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+ c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
constexpr int32_t systolicDepth{8};
std::string fnName =
@@ -352,7 +352,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
->getResult(0);
if (resOrigTy != resTy)
- result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
rewriter.replaceOp(op, result);
return success();
@@ -383,7 +383,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
auto loc = op.getLoc();
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
Value one =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
SmallVector<Value> args{op.getPtr(), one};
SmallVector<Type> argTypes;
for (auto arg : args)
@@ -439,11 +439,11 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
op, "Fence only supports workgroup and device memory scopes.");
}
Type i32Type = rewriter.getI32Type();
- Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
+ Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
Value memScopeConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
Value addrSpaceConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
SmallVector<Type> argTypes{3, i32Type};
createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
@@ -477,13 +477,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
auto i32Type = rewriter.getI32Type();
Value byteCoord =
- rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
+ LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
SmallVector<Type> retTypes;
@@ -504,11 +504,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
} else {
auto vecElemType = vecType.getElementType();
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
- Value numElems = rewriter.create<LLVM::ConstantOp>(
- loc, i32Type, vecType.getNumElements());
- auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
- numElems);
+ Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
+ vecType.getNumElements());
+ auto dstOrSrcPtr = LLVM::AllocaOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
+ vecElemType, numElems);
args.push_back(dstOrSrcPtr);
if constexpr (isLoad) { // Load
funcName += "read";
@@ -530,7 +530,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
bitWidthId = (vecElemBitWidth == 32)
? "j"
: ((vecElemBitWidth == 16) ? "t" : "h");
- rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
+ LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
paramAttrs = {
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
@@ -563,7 +563,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
}
if constexpr (isLoad)
rewriter.replaceOp(
- op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
+ op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
else
rewriter.eraseOp(op);
return success();
More information about the Mlir-commits
mailing list