[Mlir-commits] [mlir] [mlir][NFC] update `mlir/Dialect` create APIs (22/n) (PR #149929)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 15:47:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
See https://github.com/llvm/llvm-project/pull/147168 for more info.
---
Patch is 41.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149929.diff
10 Files Affected:
- (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+8-8)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+12-12)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+2-2)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+21-21)
- (modified) mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp (+3-2)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+18-18)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp (+2-2)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+49-44)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp (+25-23)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+21-21)
``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 371456552b5b5..890406df74e72 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ spirv::MergeOp::create(builder, getLoc());
}
//===----------------------------------------------------------------------===//
@@ -543,7 +543,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ spirv::MergeOp::create(builder, getLoc());
}
SelectionOp
@@ -551,7 +551,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
function_ref<void(OpBuilder &builder)> thenBody,
OpBuilder &builder) {
auto selectionOp =
- builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock(builder);
Block *mergeBlock = selectionOp.getMergeBlock();
@@ -562,17 +562,17 @@ SelectionOp::createIfThen(Location loc, Value condition,
OpBuilder::InsertionGuard guard(builder);
thenBlock = builder.createBlock(mergeBlock);
thenBody(builder);
- builder.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(builder, loc, mergeBlock);
}
// Build the header block.
{
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(thenBlock);
- builder.create<spirv::BranchConditionalOp>(
- loc, condition, thenBlock,
- /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
- /*falseArguments=*/ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
+ /*trueArguments=*/ArrayRef<Value>(),
+ mergeBlock,
+ /*falseArguments=*/ArrayRef<Value>());
}
return selectionOp;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 047f8da0cc003..2bde44baf961e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -178,16 +178,16 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
return failure();
Value addsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
Value carrysVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
// Create empty struct
- Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
// Fill in adds at id 0
Value intermediate =
- rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+ spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
// Fill in carrys at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
intermediate, 1);
@@ -260,16 +260,16 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
return failure();
Value lowBitsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
Value highBitsVal =
- rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
// Create empty struct
- Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
// Fill in lowBits at id 0
Value intermediate =
- rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+ spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
// Fill in highBits at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
intermediate, 1);
@@ -1309,11 +1309,11 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
auto storeOpAttributes =
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
- auto selectOp = rewriter.create<spirv::SelectOp>(
- selectionOp.getLoc(), trueValue.getType(),
+ auto selectOp = spirv::SelectOp::create(
+ rewriter, selectionOp.getLoc(), trueValue.getType(),
brConditionalOp.getCondition(), trueValue, falseValue);
- rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
- selectOp.getResult(), storeOpAttributes);
+ spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
+ selectOp.getResult(), storeOpAttributes);
// `spirv.mlir.selection` is not needed anymore.
rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..c9a8e97bd3296 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -940,12 +940,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
if (!spirv::ConstantOp::isBuildableWith(type))
return nullptr;
- return builder.create<spirv::ConstantOp>(loc, type, value);
+ return spirv::ConstantOp::create(builder, loc, type, value);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 656236246b1ad..52c672a05fa43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -651,26 +651,26 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getBoolAttr(false));
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getBoolAttr(false));
+ return spirv::ConstantOp::create(
+ builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
}
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getFloatAttr(floatType, 0.0));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getFloatAttr(floatType, 0.0));
}
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (llvm::isa<IntegerType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 0).getValue()));
}
if (llvm::isa<FloatType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseFPElementsAttr::get(vectorType,
FloatAttr::get(elemType, 0.0).getValue()));
}
@@ -684,26 +684,26 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getBoolAttr(true));
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getBoolAttr(true));
+ return spirv::ConstantOp::create(
+ builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
}
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getFloatAttr(floatType, 1.0));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getFloatAttr(floatType, 1.0));
}
if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (llvm::isa<IntegerType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 1).getValue()));
}
if (llvm::isa<FloatType>(elemType)) {
- return builder.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ builder, loc, type,
DenseFPElementsAttr::get(vectorType,
FloatAttr::get(elemType, 1.0).getValue()));
}
@@ -1985,7 +1985,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
OpBuilder builder(parser.getContext());
builder.setInsertionPointToEnd(&block);
- builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+ spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
result.location = wrappedOp->getLoc();
result.addTypes(wrappedOp->getResult(0).getType());
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 8da688806bade..2b9c7296830dc 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -105,8 +105,9 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
}
}
- auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
- firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
+ auto combinedModule =
+ spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
+ addressingModel, memoryModel, vceTriple);
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
// In some cases, a symbol in the (current state of the) combined module is
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 85525a5a02fa2..81365b44a3aad 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -70,9 +70,9 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
varType =
spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
- return builder.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
- abiInfo.getBinding());
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
}
/// Gets the global variables that need to be specified as interface variable
@@ -146,17 +146,17 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
return funcOp.emitRemark("lower entry point failure: could not select "
"execution model based on 'spirv.target_env'");
- builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
- interfaceVars);
+ spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
+ interfaceVars);
// Specifies the spirv.ExecutionModeOp.
if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
- spirv::ExecutionMode::LocalSize,
- workgroupSizeAttr.asArrayRef());
+ spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::LocalSize,
+ workgroupSizeAttr.asArrayRef());
// Erase workgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), DenseI32ArrayAttr(),
@@ -167,9 +167,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
- spirv::ExecutionMode::SubgroupSize,
- *subgroupSize);
+ spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::SubgroupSize,
+ *subgroupSize);
// Erase subgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
@@ -180,8 +180,8 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
std::optional<ArrayRef<spirv::Capability>> caps =
spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
if (!caps || targetEnv.allows(*caps)) {
- builder.create<spirv::ExecutionModeOp>(
- funcOp.getLoc(), funcOp,
+ spirv::ExecutionModeOp::create(
+ builder, funcOp.getLoc(), funcOp,
spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
// Erase target width.
entryPointAttr = spirv::EntryPointABIAttr::get(
@@ -259,7 +259,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// Insert spirv::AddressOf and spirv::AccessChain operations.
Value replacement =
- rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+ spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
// Check if the arg is a scalar or vector type. In that case, the value
// needs to be loaded into registers.
// TODO: This is loading value of the scalar into registers
@@ -269,9 +269,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
- auto loadPtr = rewriter.create<spirv::AccessChainOp>(
- funcOp.getLoc(), replacement, zero.getConstant());
- replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
+ auto loadPtr = spirv::AccessChainOp::create(
+ rewriter, funcOp.getLoc(), replacement, zero.getConstant());
+ replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
}
signatureConverter.remapInput(argType.index(), replacement);
}
@@ -308,7 +308,7 @@ void LowerABIAttributesPass::runOnOperation() {
ValueRange inputs, Location loc) {
if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
- return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
+ return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
});
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index ab5898d0e3925..38ef547f0769f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -65,8 +65,8 @@ void RewriteInsertsPass::runOnOperation() {
operands.push_back(insertionOp.getObject());
OpBuilder builder(lastCompositeInsertOp);
- auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
- location, compositeType, operands);
+ auto compositeConstructOp = spirv::CompositeConstructOp::create(
+ builder, location, compositeType, operands);
lastCompositeInsertOp.replaceAllUsesWith(
compositeConstructOp->getResult(0));
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..35ec0190b5a61 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -669,21 +669,24 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
Location loc) {
// We can only cast one value in SPIR-V.
if (inputs.size() != 1) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
Value input = inputs.front();
// Only support integer types for now. Floating point types to be implemented.
if (!isa<IntegerType>(type)) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
auto inputType = cast<IntegerType>(input.getType());
auto scalarType = dyn_cast<spirv::ScalarType>(type);
if (!scalarType) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
@@ -691,14 +694,15 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
// truncating to go back so we don't need to worry about the signedness.
// For extension, we cannot have enough signal here to decide which op to use.
if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castOp =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
}
// Boolean values would need to use different ops than normal integer values.
if (type.isInteger(1)) {
Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
- return builder.create<spirv::IEqualOp>(loc, input, one);
+ return spirv::IEqualOp::create(builder, loc, input, one);
}
// Check that the source integer type is supported by the environment.
@@ -708,7 +712,8 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
scalarType.getCapabilities(caps);
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
failed(checkExtensionRequirements(type, targetEnv, exts))) {
- auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto castO...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/149929
More information about the Mlir-commits mailing list