[Mlir-commits] [mlir] c3823af - [mlir][NFC] update `mlir/Dialect` create APIs (22/n) (#149929)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 21 16:57:47 PDT 2025


Author: Maksim Levental
Date: 2025-07-21T19:57:44-04:00
New Revision: c3823af156b517d926a56e3d0d585e2a15720e96

URL: https://github.com/llvm/llvm-project/commit/c3823af156b517d926a56e3d0d585e2a15720e96
DIFF: https://github.com/llvm/llvm-project/commit/c3823af156b517d926a56e3d0d585e2a15720e96.diff

LOG: [mlir][NFC] update `mlir/Dialect` create APIs (22/n) (#149929)

See https://github.com/llvm/llvm-project/pull/147168 for more info.

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 371456552b5b5..890406df74e72 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
   builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
+  spirv::MergeOp::create(builder, getLoc());
 }
 
 //===----------------------------------------------------------------------===//
@@ -543,7 +543,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
   builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
+  spirv::MergeOp::create(builder, getLoc());
 }
 
 SelectionOp
@@ -551,7 +551,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
                           function_ref<void(OpBuilder &builder)> thenBody,
                           OpBuilder &builder) {
   auto selectionOp =
-      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+      spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
 
   selectionOp.addMergeBlock(builder);
   Block *mergeBlock = selectionOp.getMergeBlock();
@@ -562,17 +562,17 @@ SelectionOp::createIfThen(Location loc, Value condition,
     OpBuilder::InsertionGuard guard(builder);
     thenBlock = builder.createBlock(mergeBlock);
     thenBody(builder);
-    builder.create<spirv::BranchOp>(loc, mergeBlock);
+    spirv::BranchOp::create(builder, loc, mergeBlock);
   }
 
   // Build the header block.
   {
     OpBuilder::InsertionGuard guard(builder);
     builder.createBlock(thenBlock);
-    builder.create<spirv::BranchConditionalOp>(
-        loc, condition, thenBlock,
-        /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
-        /*falseArguments=*/ArrayRef<Value>());
+    spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
+                                       /*trueArguments=*/ArrayRef<Value>(),
+                                       mergeBlock,
+                                       /*falseArguments=*/ArrayRef<Value>());
   }
 
   return selectionOp;

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 047f8da0cc003..2bde44baf961e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -178,16 +178,16 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
       return failure();
 
     Value addsVal =
-        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+        spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
 
     Value carrysVal =
-        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+        spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
 
     // Create empty struct
-    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
     // Fill in adds at id 0
     Value intermediate =
-        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+        spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
     // Fill in carrys at id 1
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
                                                           intermediate, 1);
@@ -260,16 +260,16 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
       return failure();
 
     Value lowBitsVal =
-        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+        spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
 
     Value highBitsVal =
-        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+        spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
 
     // Create empty struct
-    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
     // Fill in lowBits at id 0
     Value intermediate =
-        rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+        spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
     // Fill in highBits at id 1
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
                                                           intermediate, 1);
@@ -1309,11 +1309,11 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
     auto storeOpAttributes =
         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
 
-    auto selectOp = rewriter.create<spirv::SelectOp>(
-        selectionOp.getLoc(), trueValue.getType(),
+    auto selectOp = spirv::SelectOp::create(
+        rewriter, selectionOp.getLoc(), trueValue.getType(),
         brConditionalOp.getCondition(), trueValue, falseValue);
-    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
-                                    selectOp.getResult(), storeOpAttributes);
+    spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
+                           selectOp.getResult(), storeOpAttributes);
 
     // `spirv.mlir.selection` is not needed anymore.
     rewriter.eraseOp(op);

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..c9a8e97bd3296 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -940,12 +940,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poison);
+    return ub::PoisonOp::create(builder, loc, type, poison);
 
   if (!spirv::ConstantOp::isBuildableWith(type))
     return nullptr;
 
-  return builder.create<spirv::ConstantOp>(loc, type, value);
+  return spirv::ConstantOp::create(builder, loc, type, value);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 656236246b1ad..52c672a05fa43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -651,26 +651,26 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
     unsigned width = intType.getWidth();
     if (width == 1)
-      return builder.create<spirv::ConstantOp>(loc, type,
-                                               builder.getBoolAttr(false));
-    return builder.create<spirv::ConstantOp>(
-        loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
+      return spirv::ConstantOp::create(builder, loc, type,
+                                       builder.getBoolAttr(false));
+    return spirv::ConstantOp::create(
+        builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
   }
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
-    return builder.create<spirv::ConstantOp>(
-        loc, type, builder.getFloatAttr(floatType, 0.0));
+    return spirv::ConstantOp::create(builder, loc, type,
+                                     builder.getFloatAttr(floatType, 0.0));
   }
   if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
     Type elemType = vectorType.getElementType();
     if (llvm::isa<IntegerType>(elemType)) {
-      return builder.create<spirv::ConstantOp>(
-          loc, type,
+      return spirv::ConstantOp::create(
+          builder, loc, type,
           DenseElementsAttr::get(vectorType,
                                  IntegerAttr::get(elemType, 0).getValue()));
     }
     if (llvm::isa<FloatType>(elemType)) {
-      return builder.create<spirv::ConstantOp>(
-          loc, type,
+      return spirv::ConstantOp::create(
+          builder, loc, type,
           DenseFPElementsAttr::get(vectorType,
                                    FloatAttr::get(elemType, 0.0).getValue()));
     }
@@ -684,26 +684,26 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
     unsigned width = intType.getWidth();
     if (width == 1)
-      return builder.create<spirv::ConstantOp>(loc, type,
-                                               builder.getBoolAttr(true));
-    return builder.create<spirv::ConstantOp>(
-        loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
+      return spirv::ConstantOp::create(builder, loc, type,
+                                       builder.getBoolAttr(true));
+    return spirv::ConstantOp::create(
+        builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
   }
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
-    return builder.create<spirv::ConstantOp>(
-        loc, type, builder.getFloatAttr(floatType, 1.0));
+    return spirv::ConstantOp::create(builder, loc, type,
+                                     builder.getFloatAttr(floatType, 1.0));
   }
   if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
     Type elemType = vectorType.getElementType();
     if (llvm::isa<IntegerType>(elemType)) {
-      return builder.create<spirv::ConstantOp>(
-          loc, type,
+      return spirv::ConstantOp::create(
+          builder, loc, type,
           DenseElementsAttr::get(vectorType,
                                  IntegerAttr::get(elemType, 1).getValue()));
     }
     if (llvm::isa<FloatType>(elemType)) {
-      return builder.create<spirv::ConstantOp>(
-          loc, type,
+      return spirv::ConstantOp::create(
+          builder, loc, type,
           DenseFPElementsAttr::get(vectorType,
                                    FloatAttr::get(elemType, 1.0).getValue()));
     }
@@ -1985,7 +1985,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
 
   OpBuilder builder(parser.getContext());
   builder.setInsertionPointToEnd(&block);
-  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+  spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
   result.location = wrappedOp->getLoc();
 
   result.addTypes(wrappedOp->getResult(0).getType());

diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 8da688806bade..2b9c7296830dc 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -105,8 +105,9 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
     }
   }
 
-  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
-      firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
+  auto combinedModule =
+      spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
+                              addressingModel, memoryModel, vceTriple);
   combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
 
   // In some cases, a symbol in the (current state of the) combined module is

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 85525a5a02fa2..81365b44a3aad 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -70,9 +70,9 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
   varType =
       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
 
-  return builder.create<spirv::GlobalVariableOp>(
-      funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
-      abiInfo.getBinding());
+  return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+                                         varName, abiInfo.getDescriptorSet(),
+                                         abiInfo.getBinding());
 }
 
 /// Gets the global variables that need to be specified as interface variable
@@ -146,17 +146,17 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     return funcOp.emitRemark("lower entry point failure: could not select "
                              "execution model based on 'spirv.target_env'");
 
-  builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
-                                      interfaceVars);
+  spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
+                              interfaceVars);
 
   // Specifies the spirv.ExecutionModeOp.
   if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
     std::optional<ArrayRef<spirv::Capability>> caps =
         spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
     if (!caps || targetEnv.allows(*caps)) {
-      builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
-                                             spirv::ExecutionMode::LocalSize,
-                                             workgroupSizeAttr.asArrayRef());
+      spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+                                     spirv::ExecutionMode::LocalSize,
+                                     workgroupSizeAttr.asArrayRef());
       // Erase workgroup size.
       entryPointAttr = spirv::EntryPointABIAttr::get(
           entryPointAttr.getContext(), DenseI32ArrayAttr(),
@@ -167,9 +167,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     std::optional<ArrayRef<spirv::Capability>> caps =
         spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
     if (!caps || targetEnv.allows(*caps)) {
-      builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
-                                             spirv::ExecutionMode::SubgroupSize,
-                                             *subgroupSize);
+      spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
+                                     spirv::ExecutionMode::SubgroupSize,
+                                     *subgroupSize);
       // Erase subgroup size.
       entryPointAttr = spirv::EntryPointABIAttr::get(
           entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
@@ -180,8 +180,8 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     std::optional<ArrayRef<spirv::Capability>> caps =
         spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
     if (!caps || targetEnv.allows(*caps)) {
-      builder.create<spirv::ExecutionModeOp>(
-          funcOp.getLoc(), funcOp,
+      spirv::ExecutionModeOp::create(
+          builder, funcOp.getLoc(), funcOp,
           spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
       // Erase target width.
       entryPointAttr = spirv::EntryPointABIAttr::get(
@@ -259,7 +259,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
 
     // Insert spirv::AddressOf and spirv::AccessChain operations.
     Value replacement =
-        rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+        spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
     // Check if the arg is a scalar or vector type. In that case, the value
     // needs to be loaded into registers.
     // TODO: This is loading value of the scalar into registers
@@ -269,9 +269,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
       auto zero =
           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
-      auto loadPtr = rewriter.create<spirv::AccessChainOp>(
-          funcOp.getLoc(), replacement, zero.getConstant());
-      replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
+      auto loadPtr = spirv::AccessChainOp::create(
+          rewriter, funcOp.getLoc(), replacement, zero.getConstant());
+      replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
     }
     signatureConverter.remapInput(argType.index(), replacement);
   }
@@ -308,7 +308,7 @@ void LowerABIAttributesPass::runOnOperation() {
                                             ValueRange inputs, Location loc) {
     if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
       return Value();
-    return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
+    return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
   });
 
   RewritePatternSet patterns(context);

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index ab5898d0e3925..38ef547f0769f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -65,8 +65,8 @@ void RewriteInsertsPass::runOnOperation() {
       operands.push_back(insertionOp.getObject());
 
     OpBuilder builder(lastCompositeInsertOp);
-    auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
-        location, compositeType, operands);
+    auto compositeConstructOp = spirv::CompositeConstructOp::create(
+        builder, location, compositeType, operands);
 
     lastCompositeInsertOp.replaceAllUsesWith(
         compositeConstructOp->getResult(0));

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..35ec0190b5a61 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -669,21 +669,24 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
                               Location loc) {
   // We can only cast one value in SPIR-V.
   if (inputs.size() != 1) {
-    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return castOp.getResult(0);
   }
   Value input = inputs.front();
 
   // Only support integer types for now. Floating point types to be implemented.
   if (!isa<IntegerType>(type)) {
-    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return castOp.getResult(0);
   }
   auto inputType = cast<IntegerType>(input.getType());
 
   auto scalarType = dyn_cast<spirv::ScalarType>(type);
   if (!scalarType) {
-    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return castOp.getResult(0);
   }
 
@@ -691,14 +694,15 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
   // truncating to go back so we don't need to worry about the signedness.
   // For extension, we cannot have enough signal here to decide which op to use.
   if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
-    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return castOp.getResult(0);
   }
 
   // Boolean values would need to use different ops than normal integer values.
   if (type.isInteger(1)) {
     Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
-    return builder.create<spirv::IEqualOp>(loc, input, one);
+    return spirv::IEqualOp::create(builder, loc, input, one);
   }
 
   // Check that the source integer type is supported by the environment.
@@ -708,7 +712,8 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
   scalarType.getCapabilities(caps);
   if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
       failed(checkExtensionRequirements(type, targetEnv, exts))) {
-    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return castOp.getResult(0);
   }
 
@@ -716,9 +721,9 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv,
   // care about signedness here. Still try to use a corresponding op for better
   // consistency though.
   if (type.isSignedInteger()) {
-    return builder.create<spirv::SConvertOp>(loc, type, input);
+    return spirv::SConvertOp::create(builder, loc, type, input);
   }
-  return builder.create<spirv::UConvertOp>(loc, type, input);
+  return spirv::UConvertOp::create(builder, loc, type, input);
 }
 
 //===----------------------------------------------------------------------===//
@@ -770,7 +775,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                                            spirv::StorageClass::Input);
     std::string name = getBuiltinVarName(builtin, prefix, suffix);
     newVarOp =
-        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+        spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
     break;
   }
   case spirv::BuiltIn::SubgroupId:
@@ -781,7 +786,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
     std::string name = getBuiltinVarName(builtin, prefix, suffix);
     newVarOp =
-        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+        spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
     break;
   }
   default:
@@ -842,8 +847,8 @@ getOrInsertPushConstantVariable(Location loc, Block &block,
   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
   auto type = getPushConstantStorageType(elementCount, builder, indexType);
   const char *name = "__push_constant_var__";
-  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
-                                                 /*initializer=*/nullptr);
+  return spirv::GlobalVariableOp::create(builder, loc, type, name,
+                                         /*initializer=*/nullptr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -879,8 +884,8 @@ struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
     }
 
     // Create the converted spirv.func op.
-    auto newFuncOp = rewriter.create<spirv::FuncOp>(
-        funcOp.getLoc(), funcOp.getName(),
+    auto newFuncOp = spirv::FuncOp::create(
+        rewriter, funcOp.getLoc(), funcOp.getName(),
         rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                  resultType ? TypeRange(resultType)
                                             : TypeRange()));
@@ -919,8 +924,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     }
 
     // Create a new func op with the original type and copy the function body.
-    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
-                                                   funcOp.getName(), fnType);
+    auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
+                                          funcOp.getName(), fnType);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
 
@@ -954,8 +959,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
       auto origVecType = dyn_cast<VectorType>(origType);
       if (!origVecType) {
         // We need a placeholder for the old argument that will be erased later.
-        Value result = rewriter.create<arith::ConstantOp>(
-            loc, origType, rewriter.getZeroAttr(origType));
+        Value result = arith::ConstantOp::create(
+            rewriter, loc, origType, rewriter.getZeroAttr(origType));
         rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
         tmpOps.insert({result.getDefiningOp(), newInputNo});
         oneToNTypeMapping.addInputs(origInputNo, origType);
@@ -967,8 +972,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
       auto targetShape = getTargetShape(origVecType);
       if (!targetShape) {
         // We need a placeholder for the old argument that will be erased later.
-        Value result = rewriter.create<arith::ConstantOp>(
-            loc, origType, rewriter.getZeroAttr(origType));
+        Value result = arith::ConstantOp::create(
+            rewriter, loc, origType, rewriter.getZeroAttr(origType));
         rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
         tmpOps.insert({result.getDefiningOp(), newInputNo});
         oneToNTypeMapping.addInputs(origInputNo, origType);
@@ -982,12 +987,12 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
 
       // Prepare the result vector.
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, origVecType, rewriter.getZeroAttr(origVecType));
+      Value result = arith::ConstantOp::create(
+          rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
       ++newOpCount;
       // Prepare the placeholder for the new arguments that will be added later.
-      Value dummy = rewriter.create<arith::ConstantOp>(
-          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+      Value dummy = arith::ConstantOp::create(
+          rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
       ++newOpCount;
 
       // Create the `vector.insert_strided_slice` ops.
@@ -995,8 +1000,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
       SmallVector<Type> newTypes;
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
-        result = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, dummy, result, offsets, strides);
+        result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
+                                                      result, offsets, strides);
         newTypes.push_back(unrolledType);
         unrolledInputNums.push_back(newInputNo);
         ++newInputNo;
@@ -1109,12 +1114,12 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
       Value returnValue = returnOp.getOperand(origResultNo);
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
-        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, returnValue, offsets, extractShape, strides);
+        Value result = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, returnValue, offsets, extractShape, strides);
         if (originalShape.size() > 1) {
           SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
           result =
-              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
+              vector::ExtractOp::create(rewriter, loc, result, extractIndices);
         }
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
@@ -1132,7 +1137,7 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
     // Replace the return op using the new operands. This will automatically
     // update the entry block as well.
     rewriter.replaceOp(returnOp,
-                       rewriter.create<func::ReturnOp>(loc, newOperands));
+                       func::ReturnOp::create(rewriter, loc, newOperands));
 
     return success();
   }
@@ -1157,8 +1162,8 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
   spirv::GlobalVariableOp varOp =
       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
                                  builtin, integerType, builder, prefix, suffix);
-  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
-  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
+  Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
+  return spirv::LoadOp::create(builder, op->getLoc(), ptr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1179,12 +1184,12 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
 
   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
-  Value offsetOp = builder.create<spirv::ConstantOp>(
-      loc, integerType, builder.getI32IntegerAttr(offset));
-  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
-  auto acOp = builder.create<spirv::AccessChainOp>(
-      loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
-  return builder.create<spirv::LoadOp>(loc, acOp);
+  Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
+                                             builder.getI32IntegerAttr(offset));
+  auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
+  auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
+                                           llvm::ArrayRef({zeroOp, offsetOp}));
+  return spirv::LoadOp::create(builder, loc, acOp);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1244,7 +1249,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
     linearizedIndices.push_back(
         linearizeIndex(indices, strides, offset, indexType, loc, builder));
   }
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+  return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
 }
 
 Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
@@ -1275,11 +1280,11 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
       cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
   if (isa<spirv::ArrayType>(pointeeType)) {
     linearizedIndices.push_back(linearIndex);
-    return builder.create<spirv::AccessChainOp>(loc, basePtr,
-                                                linearizedIndices);
+    return spirv::AccessChainOp::create(builder, loc, basePtr,
+                                        linearizedIndices);
   }
-  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
-                                                 linearizedIndices);
+  return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
+                                         linearizedIndices);
 }
 
 Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
@@ -1465,7 +1470,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
       });
   addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
                               Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
     return cast.getResult(0);
   });
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index af1cf2a1373e3..e0900005ea1bb 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -64,16 +64,16 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
   //     and 4 additions after constant folding.
   //   - With sign-extended arguments, we end up emitting 8 multiplications and
   //     and 12 additions after CSE.
-  Value cstLowMask = rewriter.create<ConstantOp>(
-      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+  Value cstLowMask = ConstantOp::create(
+      rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
   auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
-    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+    return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
   };
 
-  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                            getScalarOrSplatAttr(argTy, 16));
+  Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
+                                   getScalarOrSplatAttr(argTy, 16));
   auto getHighDigit = [&rewriter, loc, cst16](Value val) {
-    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+    return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
   };
 
   auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
@@ -82,11 +82,11 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
     // fine. We do not have to introduce an extra constant since any
     // value in [15, 32) would do.
     return getHighDigit(
-        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+        ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
   };
 
-  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                           getScalarOrSplatAttr(argTy, 0));
+  Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
+                                  getScalarOrSplatAttr(argTy, 0));
 
   Value lhsLow = getLowDigit(lhs);
   Value lhsHigh = getHighDigit(lhs);
@@ -108,7 +108,7 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
         continue;
 
       Value &thisResDigit = resultDigits[i + j];
-      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+      Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
       Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
       thisResDigit = getLowDigit(current);
 
@@ -122,14 +122,15 @@ static Value lowerExtendedMultiplication(Operation *mulOp,
   }
 
   auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
-    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
-    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+    Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
+    return BitwiseOrOp::create(rewriter, loc, low, highBits);
   };
   Value low = combineDigits(resultDigits[0], resultDigits[1]);
   Value high = combineDigits(resultDigits[2], resultDigits[3]);
 
-  return rewriter.create<CompositeConstructOp>(
-      loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
+  return CompositeConstructOp::create(rewriter, loc,
+                                      mulOp->getResultTypes().front(),
+                                      llvm::ArrayRef({low, high}));
 }
 
 //===----------------------------------------------------------------------===//
@@ -184,18 +185,19 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    Value one =
-        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
-    Value zero =
-        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
+    Value one = ConstantOp::create(rewriter, loc, argTy,
+                                   getScalarOrSplatAttr(argTy, 1));
+    Value zero = ConstantOp::create(rewriter, loc, argTy,
+                                    getScalarOrSplatAttr(argTy, 0));
 
     // Calculate the carry by checking if the addition resulted in an overflow.
-    Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
-    Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
-    Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
+    Value out = IAddOp::create(rewriter, loc, lhs, rhs);
+    Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
+    Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
 
-    Value add = rewriter.create<CompositeConstructOp>(
-        loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
+    Value add = CompositeConstructOp::create(rewriter, loc,
+                                             op->getResultTypes().front(),
+                                             llvm::ArrayRef({out, carry}));
 
     rewriter.replaceOp(op, add);
     return success();

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 527d92634c196..692f2e7616e5a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -380,13 +380,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       Type indexType = oldIndex.getType();
 
       int ratio = dstNumBytes / srcNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+      auto ratioValue = spirv::ConstantOp::create(
+          rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
 
       indices.back() =
-          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
-      indices.push_back(
-          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
+          spirv::SDivOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
+      indices.push_back(spirv::SModOp::create(rewriter, loc, indexType,
+                                              oldIndex, ratioValue));
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);
@@ -407,11 +407,11 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       Type indexType = oldIndex.getType();
 
       int ratio = srcNumBytes / dstNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+      auto ratioValue = spirv::ConstantOp::create(
+          rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
 
       indices.back() =
-          rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
+          spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);
@@ -435,15 +435,15 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
     auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
 
     Location loc = loadOp.getLoc();
-    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
+    auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr());
     if (srcElemType == dstElemType) {
       rewriter.replaceOp(loadOp, newLoadOp->getResults());
       return success();
     }
 
     if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
-      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
-                                                      newLoadOp.getValue());
+      auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType,
+                                             newLoadOp.getValue());
       rewriter.replaceOp(loadOp, castOp->getResults());
 
       return success();
@@ -475,14 +475,14 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
       auto indices = llvm::to_vector<4>(acOp.getIndices());
       for (int i = 1; i < ratio; ++i) {
         // Load all subsequent components belonging to this element.
-        indices.back() = rewriter.create<spirv::IAddOp>(
-            loc, i32Type, indices.back(), oneValue);
-        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
-            loc, acOp.getBasePtr(), indices);
+        indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type,
+                                               indices.back(), oneValue);
+        auto componentAcOp = spirv::AccessChainOp::create(
+            rewriter, loc, acOp.getBasePtr(), indices);
         // Assuming little endian, this reads lower-ordered bits of the number
         // to lower-numbered components of the vector.
         components.push_back(
-            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+            spirv::LoadOp::create(rewriter, loc, componentAcOp));
       }
 
       // Create a vector of the components and then cast back to the larger
@@ -510,15 +510,15 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
               castType = VectorType::get({count}, castType);
 
             for (Value &c : components)
-              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
+              c = spirv::BitcastOp::create(rewriter, loc, castType, c);
           }
         }
-      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
-          loc, vectorType, components);
+      Value vectorValue = spirv::CompositeConstructOp::create(
+          rewriter, loc, vectorType, components);
 
       if (!isa<VectorType>(srcElemType))
         vectorValue =
-            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
+            spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue);
       rewriter.replaceOp(loadOp, vectorValue);
       return success();
     }
@@ -546,7 +546,7 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
     Location loc = storeOp.getLoc();
     Value value = adaptor.getValue();
     if (srcElemType != dstElemType)
-      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
+      value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value);
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
                                                 value, storeOp->getAttrs());
     return success();


        


More information about the Mlir-commits mailing list