[Mlir-commits] [mlir] c3823af - [mlir][NFC] update `mlir/Dialect` create APIs (22/n) (#149929)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 16:57:47 PDT 2025
Author: Maksim Levental
Date: 2025-07-21T19:57:44-04:00
New Revision: c3823af156b517d926a56e3d0d585e2a15720e96
URL: https://github.com/llvm/llvm-project/commit/c3823af156b517d926a56e3d0d585e2a15720e96
DIFF: https://github.com/llvm/llvm-project/commit/c3823af156b517d926a56e3d0d585e2a15720e96.diff
LOG: [mlir][NFC] update `mlir/Dialect` create APIs (22/n) (#149929)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 371456552b5b5..890406df74e72 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ spirv::MergeOp::create(builder, getLoc());
}
//===----------------------------------------------------------------------===//
@@ -543,7 +543,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ spirv::MergeOp::create(builder, getLoc());
}
SelectionOp
@@ -551,7 +551,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
function_ref<void(OpBuilder &builder)> thenBody,
OpBuilder &builder) {
auto selectionOp =
- builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock(builder);
Block *mergeBlock = selectionOp.getMergeBlock();
@@ -562,17 +562,17 @@ SelectionOp::createIfThen(Location loc, Value condition,
OpBuilder::InsertionGuard guard(builder);
thenBlock = builder.createBlock(mergeBlock);
thenBody(builder);
- builder.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(builder, loc, mergeBlock);
}
// Build the header block.
{
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(thenBlock);
- builder.create<spirv::BranchConditionalOp>(
- loc, condition, thenBlock,
- /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
- /*falseArguments=*/ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
+ /*trueArguments=*/ArrayRef<Value>(),
+ mergeBlock,
+ /*falseArguments=*/ArrayRef<Value>());
}
return selectionOp;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 047f8da0cc003..2bde44baf961e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -178,16 +178,16 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
return failure();
Value addsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
Value carrysVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
// Create empty struct
- Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
// Fill in adds at id 0
Value intermediate =
- rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+ spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
// Fill in carrys at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
intermediate, 1);
@@ -260,16 +260,16 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
return failure();
Value lowBitsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
Value highBitsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
// Create empty struct
- Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
// Fill in lowBits at id 0
Value intermediate =
- rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+ spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
// Fill in highBits at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
intermediate, 1);
@@ -1309,11 +1309,11 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
auto storeOpAttributes =
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
- auto selectOp = rewriter.create<spirv::SelectOp>(
- selectionOp.getLoc(), trueValue.getType(),
+ auto selectOp = spirv::SelectOp::create(
+ rewriter, selectionOp.getLoc(), trueValue.getType(),
brConditionalOp.getCondition(), trueValue, falseValue);
- rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
- selectOp.getResult(), storeOpAttributes);
+ spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
+ selectOp.getResult(), storeOpAttributes);
// `spirv.mlir.selection` is not needed anymore.
rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..c9a8e97bd3296 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -940,12 +940,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
if (!spirv::ConstantOp::isBuildableWith(type))
return nullptr;
- return builder.create<spirv::ConstantOp>(loc, type, value);
+ return spirv::ConstantOp::create(builder, loc, type, value);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 656236246b1ad..52c672a05fa43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -651,26 +651,26 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getBoolAttr(false));
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getBoolAttr(false));
+ return spirv::ConstantOp::create(
+ builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
}
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getFloatAttr(floatType, 0.0));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getFloatAttr(floatType, 0.0));
}
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (llvm::isa<IntegerType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 0).getValue()));
}
if (llvm::isa<FloatType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseFPElementsAttr::get(vectorType,
FloatAttr::get(elemType, 0.0).getValue()));
}
@@ -684,26 +684,26 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getBoolAttr(true));
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getBoolAttr(true));
+ return spirv::ConstantOp::create(
+ builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
}
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getFloatAttr(floatType, 1.0));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getFloatAttr(floatType, 1.0));
}
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (llvm::isa<IntegerType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 1).getValue()));
}
if (llvm::isa<FloatType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseFPElementsAttr::get(vectorType,
FloatAttr::get(elemType, 1.0).getValue()));
}
@@ -1985,7 +1985,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
OpBuilder builder(parser.getContext());
builder.setInsertionPointToEnd(&block);
- builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+ spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
result.location = wrappedOp->getLoc();
result.addTypes(wrappedOp->getResult(0).getType());
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 8da688806bade..2b9c7296830dc 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -105,8 +105,9 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
}
}
- auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
- firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
+ auto combinedModule =
+ spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
+ addressingModel, memoryModel, vceTriple);
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
// In some cases, a symbol in the (current state of the) combined module is
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 85525a5a02fa2..81365b44a3aad 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -70,9 +70,9 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
varType =
spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
- return builder.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
- abiInfo.getBinding());
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
}
/// Gets the global variables that need to be specified as interface variable
@@ -146,17 +146,17 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
return funcOp.emitRemark("lower entry point failure: could not select "
"execution model based on 'spirv.target_env'");
- builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
- interfaceVars);
+ spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
+ interfaceVars);
// Specifies the spirv.ExecutionModeOp.
if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
- spirv::ExecutionMode::LocalSize,
- workgroupSizeAttr.asArrayRef());
+ spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::LocalSize,
+ workgroupSizeAttr.asArrayRef());
// Erase workgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), DenseI32ArrayAttr(),
@@ -167,9 +167,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
- spirv::ExecutionMode::SubgroupSize,
- *subgroupSize);
+ spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::SubgroupSize,
+ *subgroupSize);
// Erase subgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
@@ -180,8 +180,8 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(
- funcOp.getLoc(), funcOp,
+ spirv::ExecutionModeOp::create(
+ builder, funcOp.getLoc(), funcOp,
spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
// Erase target width.
entryPointAttr = spirv::EntryPointABIAttr::get(
@@ -259,7 +259,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// Insert spirv::AddressOf and spirv::AccessChain operations.
Value replacement =
- rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+ spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
// Check if the arg is a scalar or vector type. In that case, the value
// needs to be loaded into registers.
// TODO: This is loading value of the scalar into registers
@@ -269,9 +269,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
- auto loadPtr = rewriter.create<spirv::AccessChainOp>(
- funcOp.getLoc(), replacement, zero.getConstant());
- replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
+ auto loadPtr = spirv::AccessChainOp::create(
+ rewriter, funcOp.getLoc(), replacement, zero.getConstant());
+ replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
}
signatureConverter.remapInput(argType.index(), replacement);
}
@@ -308,7 +308,7 @@ void LowerABIAttributesPass::runOnOperation() {
ValueRange inputs, Location loc) {
if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
- return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
+ return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
});
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index ab5898d0e3925..38ef547f0769f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -65,8 +65,8 @@ void RewriteInsertsPass::runOnOperation() {
operands.push_back(insertionOp.getObject());
OpBuilder builder(lastCompositeInsertOp);
- auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
- location, compositeType, operands);
+ auto compositeConstructOp = spirv::CompositeConstructOp::create(
+ builder, location, compositeType, operands);
lastCompositeInsertOp.replaceAllUsesWith(
compositeConstructOp->getResult(0));
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..35ec0190b5a61 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -669,21 +669,24 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
Location loc) {
// We can only cast one value in SPIR-V.
if (inputs.size() != 1) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
Value input = inputs.front();
// Only support integer types for now. Floating point types to be implemented.
if (!isa<IntegerType>(type)) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
auto inputType = cast<IntegerType>(input.getType());
auto scalarType = dyn_cast<spirv::ScalarType>(type);
if (!scalarType) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
@@ -691,14 +694,15 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
// truncating to go back so we don't need to worry about the signedness.
// For extension, we cannot have enough signal here to decide which op to use.
if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
// Boolean values would need to use different ops than normal integer values.
if (type.isInteger(1)) {
Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
- return builder.create<spirv::IEqualOp>(loc, input, one);
+ return spirv::IEqualOp::create(builder, loc, input, one);
}
// Check that the source integer type is supported by the environment.
@@ -708,7 +712,8 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
scalarType.getCapabilities(caps);
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
failed(checkExtensionRequirements(type, targetEnv, exts))) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
@@ -716,9 +721,9 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
// care about signedness here. Still try to use a corresponding op for better
// consistency though.
if (type.isSignedInteger()) {
- return builder.create<spirv::SConvertOp>(loc, type, input);
+ return spirv::SConvertOp::create(builder, loc, type, input);
}
- return builder.create<spirv::UConvertOp>(loc, type, input);
+ return spirv::UConvertOp::create(builder, loc, type, input);
}
//===----------------------------------------------------------------------===//
@@ -770,7 +775,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
spirv::StorageClass::Input);
std::string name = getBuiltinVarName(builtin, prefix, suffix);
newVarOp =
- builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
break;
}
case spirv::BuiltIn::SubgroupId:
@@ -781,7 +786,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
spirv::PointerType::get(integerType, spirv::StorageClass::Input);
std::string name = getBuiltinVarName(builtin, prefix, suffix);
newVarOp =
- builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
break;
}
default:
@@ -842,8 +847,8 @@ getOrInsertPushConstantVariable(Location loc, Block &block,
auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
auto type = getPushConstantStorageType(elementCount, builder, indexType);
const char *name = "__push_constant_var__";
- return builder.create<spirv::GlobalVariableOp>(loc, type, name,
- /*initializer=*/nullptr);
+ return spirv::GlobalVariableOp::create(builder, loc, type, name,
+ /*initializer=*/nullptr);
}
//===----------------------------------------------------------------------===//
@@ -879,8 +884,8 @@ struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
}
// Create the converted spirv.func op.
- auto newFuncOp = rewriter.create<spirv::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ auto newFuncOp = spirv::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
resultType ? TypeRange(resultType)
: TypeRange()));
@@ -919,8 +924,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
}
// Create a new func op with the original type and copy the function body.
- auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
- funcOp.getName(), fnType);
+ auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
+ funcOp.getName(), fnType);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
@@ -954,8 +959,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
auto origVecType = dyn_cast<VectorType>(origType);
if (!origVecType) {
// We need a placeholder for the old argument that will be erased later.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, origType, rewriter.getZeroAttr(origType));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
@@ -967,8 +972,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
// We need a placeholder for the old argument that will be erased later.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, origType, rewriter.getZeroAttr(origType));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
@@ -982,12 +987,12 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
// Prepare the result vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, origVecType, rewriter.getZeroAttr(origVecType));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
++newOpCount;
// Prepare the placeholder for the new arguments that will be added later.
- Value dummy = rewriter.create<arith::ConstantOp>(
- loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+ Value dummy = arith::ConstantOp::create(
+ rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
++newOpCount;
// Create the `vector.insert_strided_slice` ops.
@@ -995,8 +1000,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Type> newTypes;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, dummy, result, offsets, strides);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
+ result, offsets, strides);
newTypes.push_back(unrolledType);
unrolledInputNums.push_back(newInputNo);
++newInputNo;
@@ -1109,12 +1114,12 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
- Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, extractShape, strides);
+ Value result = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, returnValue, offsets, extractShape, strides);
if (originalShape.size() > 1) {
SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
result =
- rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
+ vector::ExtractOp::create(rewriter, loc, result, extractIndices);
}
newOperands.push_back(result);
newTypes.push_back(unrolledType);
@@ -1132,7 +1137,7 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// Replace the return op using the new operands. This will automatically
// update the entry block as well.
rewriter.replaceOp(returnOp,
- rewriter.create<func::ReturnOp>(loc, newOperands));
+ func::ReturnOp::create(rewriter, loc, newOperands));
return success();
}
@@ -1157,8 +1162,8 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
spirv::GlobalVariableOp varOp =
getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
builtin, integerType, builder, prefix, suffix);
- Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
- return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
+ Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
+ return spirv::LoadOp::create(builder, op->getLoc(), ptr);
}
//===----------------------------------------------------------------------===//
@@ -1179,12 +1184,12 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
loc, parent->getRegion(0).front(), elementCount, builder, integerType);
Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
- Value offsetOp = builder.create<spirv::ConstantOp>(
- loc, integerType, builder.getI32IntegerAttr(offset));
- auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
- auto acOp = builder.create<spirv::AccessChainOp>(
- loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
- return builder.create<spirv::LoadOp>(loc, acOp);
+ Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
+ builder.getI32IntegerAttr(offset));
+ auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
+ auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
+ llvm::ArrayRef({zeroOp, offsetOp}));
+ return spirv::LoadOp::create(builder, loc, acOp);
}
//===----------------------------------------------------------------------===//
@@ -1244,7 +1249,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
linearizedIndices.push_back(
linearizeIndex(indices, strides, offset, indexType, loc, builder));
}
- return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+ return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
}
Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
@@ -1275,11 +1280,11 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
if (isa<spirv::ArrayType>(pointeeType)) {
linearizedIndices.push_back(linearIndex);
- return builder.create<spirv::AccessChainOp>(loc, basePtr,
- linearizedIndices);
+ return spirv::AccessChainOp::create(builder, loc, basePtr,
+ linearizedIndices);
}
- return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
- linearizedIndices);
+ return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
+ linearizedIndices);
}
Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
@@ -1465,7 +1470,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
});
addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return cast.getResult(0);
});
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index af1cf2a1373e3..e0900005ea1bb 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -64,16 +64,16 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
// and 4 additions after constant folding.
// - With sign-extended arguments, we end up emitting 8 multiplications and
// and 12 additions after CSE.
- Value cstLowMask = rewriter.create<ConstantOp>(
- loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+ Value cstLowMask = ConstantOp::create(
+ rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
- return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+ return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
};
- Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
- getScalarOrSplatAttr(argTy, 16));
+ Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
+ getScalarOrSplatAttr(argTy, 16));
auto getHighDigit = [&rewriter, loc, cst16](Value val) {
- return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+ return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
};
auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
@@ -82,11 +82,11 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
// fine. We do not have to introduce an extra constant since any
// value in [15, 32) would do.
return getHighDigit(
- rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+ ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
};
- Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
- getScalarOrSplatAttr(argTy, 0));
+ Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
+ getScalarOrSplatAttr(argTy, 0));
Value lhsLow = getLowDigit(lhs);
Value lhsHigh = getHighDigit(lhs);
@@ -108,7 +108,7 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
continue;
Value &thisResDigit = resultDigits[i + j];
- Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+ Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
thisResDigit = getLowDigit(current);
@@ -122,14 +122,15 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
}
auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
- Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
- return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+ Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
+ return BitwiseOrOp::create(rewriter, loc, low, highBits);
};
Value low = combineDigits(resultDigits[0], resultDigits[1]);
Value high = combineDigits(resultDigits[2], resultDigits[3]);
- return rewriter.create<CompositeConstructOp>(
- loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
+ return CompositeConstructOp::create(rewriter, loc,
+ mulOp->getResultTypes().front(),
+ llvm::ArrayRef({low, high}));
}
//===----------------------------------------------------------------------===//
@@ -184,18 +185,19 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
- Value one =
- rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
- Value zero =
- rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
+ Value one = ConstantOp::create(rewriter, loc, argTy,
+ getScalarOrSplatAttr(argTy, 1));
+ Value zero = ConstantOp::create(rewriter, loc, argTy,
+ getScalarOrSplatAttr(argTy, 0));
// Calculate the carry by checking if the addition resulted in an overflow.
- Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
- Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
- Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
+ Value out = IAddOp::create(rewriter, loc, lhs, rhs);
+ Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
+ Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
- Value add = rewriter.create<CompositeConstructOp>(
- loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
+ Value add = CompositeConstructOp::create(rewriter, loc,
+ op->getResultTypes().front(),
+ llvm::ArrayRef({out, carry}));
rewriter.replaceOp(op, add);
return success();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 527d92634c196..692f2e7616e5a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -380,13 +380,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
Type indexType = oldIndex.getType();
int ratio = dstNumBytes / srcNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+ auto ratioValue = spirv::ConstantOp::create(
+ rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
indices.back() =
- rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
- indices.push_back(
- rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
+ spirv::SDivOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
+ indices.push_back(spirv::SModOp::create(rewriter, loc, indexType,
+ oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
@@ -407,11 +407,11 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
Type indexType = oldIndex.getType();
int ratio = srcNumBytes / dstNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+ auto ratioValue = spirv::ConstantOp::create(
+ rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
indices.back() =
- rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
+ spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
@@ -435,15 +435,15 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
Location loc = loadOp.getLoc();
- auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
+ auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr());
if (srcElemType == dstElemType) {
rewriter.replaceOp(loadOp, newLoadOp->getResults());
return success();
}
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
- auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
- newLoadOp.getValue());
+ auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType,
+ newLoadOp.getValue());
rewriter.replaceOp(loadOp, castOp->getResults());
return success();
@@ -475,14 +475,14 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
auto indices = llvm::to_vector<4>(acOp.getIndices());
for (int i = 1; i < ratio; ++i) {
// Load all subsequent components belonging to this element.
- indices.back() = rewriter.create<spirv::IAddOp>(
- loc, i32Type, indices.back(), oneValue);
- auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
- loc, acOp.getBasePtr(), indices);
+ indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type,
+ indices.back(), oneValue);
+ auto componentAcOp = spirv::AccessChainOp::create(
+ rewriter, loc, acOp.getBasePtr(), indices);
// Assuming little endian, this reads lower-ordered bits of the number
// to lower-numbered components of the vector.
components.push_back(
- rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+ spirv::LoadOp::create(rewriter, loc, componentAcOp));
}
// Create a vector of the components and then cast back to the larger
@@ -510,15 +510,15 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
castType = VectorType::get({count}, castType);
for (Value &c : components)
- c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
+ c = spirv::BitcastOp::create(rewriter, loc, castType, c);
}
}
- Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
- loc, vectorType, components);
+ Value vectorValue = spirv::CompositeConstructOp::create(
+ rewriter, loc, vectorType, components);
if (!isa<VectorType>(srcElemType))
vectorValue =
- rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
+ spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue);
rewriter.replaceOp(loadOp, vectorValue);
return success();
}
@@ -546,7 +546,7 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
Location loc = storeOp.getLoc();
Value value = adaptor.getValue();
if (srcElemType != dstElemType)
- value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
+ value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
value, storeOp->getAttrs());
return success();
More information about the Mlir-commits
mailing list