[Mlir-commits] [mlir] [mlir] update create APIs (1/n) (PR #149656)

Maksim Levental llvmlistbot at llvm.org
Sat Jul 19 08:55:17 PDT 2025


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/149656

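This PR updates the `create` APIs for the arith and affine dialects. It adds static `create` methods (one taking an `OpBuilder` plus a `Location`, one taking an `ImplicitLocOpBuilder`) to the ops listed below, which are the only in-tree ops with "custom" builders, and migrates call sites in these two dialects from `builder.create<Op>(...)` to `Op::create(builder, ...)`: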

```
AffineDmaStartOp
AffineDmaWaitOp
ConstantIntOp
ConstantFloatOp
ConstantIndexOp
```

>From d3556c9d1585c0a16427e57ad768400cc06db2cb Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 7 Jul 2025 12:56:56 -0400
Subject: [PATCH] [mlir] update create APIs (1/n)

---
 .../mlir/Dialect/Affine/IR/AffineOps.h        |  21 +++
 mlir/include/mlir/Dialect/Arith/IR/Arith.h    |  19 +++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 115 +++++++++-----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 141 ++++++++++++++----
 4 files changed, 229 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 2091faa6b0b02..333de6bbd8a05 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -114,6 +114,21 @@ class AffineDmaStartOp
                     AffineMap tagMap, ValueRange tagIndices, Value numElements,
                     Value stride = nullptr, Value elementsPerStride = nullptr);
 
+  static AffineDmaStartOp
+  create(OpBuilder &builder, Location location, Value srcMemRef,
+         AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
+         AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
+         AffineMap tagMap, ValueRange tagIndices, Value numElements,
+         Value stride = nullptr, Value elementsPerStride = nullptr);
+
+  static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef,
+                                 AffineMap srcMap, ValueRange srcIndices,
+                                 Value destMemRef, AffineMap dstMap,
+                                 ValueRange destIndices, Value tagMemRef,
+                                 AffineMap tagMap, ValueRange tagIndices,
+                                 Value numElements, Value stride = nullptr,
+                                 Value elementsPerStride = nullptr);
+
   /// Returns the operand index of the source memref.
   unsigned getSrcMemRefOperandIndex() { return 0; }
 
@@ -319,6 +334,12 @@ class AffineDmaWaitOp
 
   static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
                     AffineMap tagMap, ValueRange tagIndices, Value numElements);
+  static AffineDmaWaitOp create(OpBuilder &builder, Location location,
+                                Value tagMemRef, AffineMap tagMap,
+                                ValueRange tagIndices, Value numElements);
+  static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef,
+                                AffineMap tagMap, ValueRange tagIndices,
+                                Value numElements);
 
   static StringRef getOperationName() { return "affine.dma_wait"; }
 
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 7c50c2036ffdc..0fc3db8e993d8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp {
   /// Build a constant int op that produces an integer of the specified width.
   static void build(OpBuilder &builder, OperationState &result, int64_t value,
                     unsigned width);
+  static ConstantIntOp create(OpBuilder &builder, Location location,
+                              int64_t value, unsigned width);
+  static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value,
+                              unsigned width);
 
   /// Build a constant int op that produces an integer of the specified type,
   /// which must be an integer type.
   static void build(OpBuilder &builder, OperationState &result, Type type,
                     int64_t value);
+  static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
+                              int64_t value);
+  static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
+                              int64_t value);
 
   /// Build a constant int op that produces an integer from an APInt
   static void build(OpBuilder &builder, OperationState &result, Type type,
                     const APInt &value);
+  static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
+                              const APInt &value);
+  static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
+                              const APInt &value);
 
   inline int64_t value() {
     return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp {
   /// Build a constant float op that produces a float of the specified type.
   static void build(OpBuilder &builder, OperationState &result, FloatType type,
                     const APFloat &value);
+  static ConstantFloatOp create(OpBuilder &builder, Location location,
+                                FloatType type, const APFloat &value);
+  static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type,
+                                const APFloat &value);
 
   inline APFloat value() {
     return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
@@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp {
   static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
   /// Build a constant int op that produces an index.
   static void build(OpBuilder &builder, OperationState &result, int64_t value);
+  static ConstantIndexOp create(OpBuilder &builder, Location location,
+                                int64_t value);
+  static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value);
 
   inline int64_t value() {
     return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e8e8f624d806e..110e9f5572973 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -228,6 +228,7 @@ void AffineDialect::initialize() {
   addOperations<AffineDmaStartOp, AffineDmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
+
                 >();
   addInterfaces<AffineInlinerInterface>();
   declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
@@ -240,7 +241,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poison);
+    return ub::PoisonOp::create(builder, loc, type, poison);
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
@@ -1282,7 +1283,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
   map = foldAttributesIntoMap(b, map, operands, valueOperands);
   composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
   assert(map);
-  return b.create<AffineApplyOp>(loc, map, valueOperands);
+  return AffineApplyOp::create(b, loc, map, valueOperands);
 }
 
 AffineApplyOp
@@ -1389,7 +1390,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
   SmallVector<Value> valueOperands;
   map = foldAttributesIntoMap(b, map, operands, valueOperands);
   composeMultiResultAffineMap(map, valueOperands);
-  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
+  return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
 }
 
 AffineMinOp
@@ -1747,6 +1748,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
   }
 }
 
+AffineDmaStartOp AffineDmaStartOp::create(
+    OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
+    ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
+    ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
+    ValueRange tagIndices, Value numElements, Value stride,
+    Value elementsPerStride) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
+        destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
+        elementsPerStride);
+  auto result = llvm::dyn_cast<AffineDmaStartOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+AffineDmaStartOp AffineDmaStartOp::create(
+    ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
+    ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
+    ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
+    ValueRange tagIndices, Value numElements, Value stride,
+    Value elementsPerStride) {
+  return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
+                destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
+                numElements, stride, elementsPerStride);
+}
+
 void AffineDmaStartOp::print(OpAsmPrinter &p) {
   p << " " << getSrcMemRef() << '[';
   p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
@@ -1917,6 +1944,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
   result.addOperands(numElements);
 }
 
+AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
+                                        Value tagMemRef, AffineMap tagMap,
+                                        ValueRange tagIndices,
+                                        Value numElements) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
+  auto result = llvm::dyn_cast<AffineDmaWaitOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
+                                        Value tagMemRef, AffineMap tagMap,
+                                        ValueRange tagIndices,
+                                        Value numElements) {
+  return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
+                numElements);
+}
+
 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
   p << " " << getTagMemRef() << '[';
   SmallVector<Value, 2> operands(getTagIndices());
@@ -2688,8 +2734,8 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
   rewriter.setInsertionPoint(getOperation());
   auto inits = llvm::to_vector(getInits());
   inits.append(newInitOperands.begin(), newInitOperands.end());
-  AffineForOp newLoop = rewriter.create<AffineForOp>(
-      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
+  AffineForOp newLoop = AffineForOp::create(
+      rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
       getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
 
   // Generate the new yield values and append them to the scf.yield operation.
@@ -2831,7 +2877,7 @@ static void buildAffineLoopNestImpl(
         OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
         bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
       }
-      nestedBuilder.create<AffineYieldOp>(nestedLoc);
+      AffineYieldOp::create(nestedBuilder, nestedLoc);
     };
 
     // Delegate actual loop creation to the callback in order to dispatch
@@ -2846,8 +2892,8 @@ static AffineForOp
 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                              int64_t ub, int64_t step,
                              AffineForOp::BodyBuilderFn bodyBuilderFn) {
-  return builder.create<AffineForOp>(loc, lb, ub, step,
-                                     /*iterArgs=*/ValueRange(), bodyBuilderFn);
+  return AffineForOp::create(builder, loc, lb, ub, step,
+                             /*iterArgs=*/ValueRange(), bodyBuilderFn);
 }
 
 /// Creates an affine loop from the bounds that may or may not be constants.
@@ -2860,9 +2906,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
   if (lbConst && ubConst)
     return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
                                         ubConst.value(), step, bodyBuilderFn);
-  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
-                                     builder.getDimIdentityMap(), step,
-                                     /*iterArgs=*/ValueRange(), bodyBuilderFn);
+  return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
+                             builder.getDimIdentityMap(), step,
+                             /*iterArgs=*/ValueRange(), bodyBuilderFn);
 }
 
 void mlir::affine::buildAffineLoopNest(
@@ -4883,7 +4929,7 @@ struct DropUnitExtentBasis
     Location loc = delinearizeOp->getLoc();
     auto getZero = [&]() -> Value {
       if (!zero)
-        zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+        zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
       return zero.value();
     };
 
@@ -4906,8 +4952,8 @@ struct DropUnitExtentBasis
 
     if (!newBasis.empty()) {
       // Will drop the leading nullptr from `basis` if there was no outer bound.
-      auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-          loc, delinearizeOp.getLinearIndex(), newBasis);
+      auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
+          rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
       int newIndex = 0;
       // Map back the new delinearized indices to the values they replace.
       for (auto &replacement : replacements) {
@@ -4971,12 +5017,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
       return success();
     }
 
-    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
-        linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
+    Value newLinearize = affine::AffineLinearizeIndexOp::create(
+        rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
         ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
         linearizeOp.getDisjoint());
-    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        delinearizeOp.getLoc(), newLinearize,
+    auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
+        rewriter, delinearizeOp.getLoc(), newLinearize,
         ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
         delinearizeOp.hasOuterBound());
     SmallVector<Value> mergedResults(newDelinearize.getResults());
@@ -5048,19 +5094,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
           delinearizeOp,
           "need at least two elements to form the basis product");
 
-    Value linearizeWithoutBack =
-        rewriter.create<affine::AffineLinearizeIndexOp>(
-            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
-            linearizeOp.getDynamicBasis(),
-            linearizeOp.getStaticBasis().drop_back(),
-            linearizeOp.getDisjoint());
-    auto delinearizeWithoutSplitPart =
-        rewriter.create<affine::AffineDelinearizeIndexOp>(
-            delinearizeOp.getLoc(), linearizeWithoutBack,
-            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
-            delinearizeOp.hasOuterBound());
-    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+    Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
+        rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+        linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
+        linearizeOp.getDisjoint());
+    auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
+        rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
+        delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+        delinearizeOp.hasOuterBound());
+    auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
+        rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
         basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
     SmallVector<Value> results = llvm::to_vector(
         llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
@@ -5272,7 +5315,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder,
   }
   if (auto constant = dyn_cast<AffineConstantExpr>(result))
     return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
-  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
 }
 
 /// If conseceutive outputs of a delinearize_index are linearized with the same
@@ -5437,16 +5480,16 @@ struct CancelLinearizeOfDelinearizePortion final
       newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                           newDelinBasis.begin() + m.delinStart + m.length);
       newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
-      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
-          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+      auto newDelinearize = AffineDelinearizeIndexOp::create(
+          rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
           newDelinBasis);
 
       // Since there may be other uses of the indices we just merged together,
       // create a residual affine.delinearize_index that delinearizes the
       // merged output into its component parts.
       Value combinedElem = newDelinearize.getResult(m.delinStart);
-      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
-          m.delinearize.getLoc(), combinedElem, basisToMerge);
+      auto residualDelinearize = AffineDelinearizeIndexOp::create(
+          rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
 
       // Swap all the uses of the unaffected delinearize outputs to the new
       // delinearization so that the old code can be removed if this
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 4e40d4ebda004..ac00917d33a85 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -242,7 +242,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                           Type type, Location loc) {
   if (isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+    return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
   return nullptr;
 }
 
@@ -255,18 +255,66 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                            builder.getIntegerAttr(type, value));
 }
 
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+                                                  Location location,
+                                                  int64_t value,
+                                                  unsigned width) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, value, width);
+  auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+                                                  int64_t value,
+                                                  unsigned width) {
+  return create(builder, builder.getLoc(), value, width);
+}
+
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                  Type type, int64_t value) {
   arith::ConstantOp::build(builder, result, type,
                            builder.getIntegerAttr(type, value));
 }
 
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+                                                  Location location, Type type,
+                                                  int64_t value) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, type, value);
+  auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+                                                  Type type, int64_t value) {
+  return create(builder, builder.getLoc(), type, value);
+}
+
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                  Type type, const APInt &value) {
   arith::ConstantOp::build(builder, result, type,
                            builder.getIntegerAttr(type, value));
 }
 
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+                                                  Location location, Type type,
+                                                  const APInt &value) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, type, value);
+  auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+                                                  Type type,
+                                                  const APInt &value) {
+  return create(builder, builder.getLoc(), type, value);
+}
+
 bool arith::ConstantIntOp::classof(Operation *op) {
   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
     return constOp.getType().isSignlessInteger();
@@ -279,6 +327,23 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                            builder.getFloatAttr(type, value));
 }
 
+arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
+                                                      Location location,
+                                                      FloatType type,
+                                                      const APFloat &value) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, type, value);
+  auto result = llvm::dyn_cast<ConstantFloatOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+arith::ConstantFloatOp
+arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder, FloatType type,
+                               const APFloat &value) {
+  return create(builder, builder.getLoc(), type, value);
+}
+
 bool arith::ConstantFloatOp::classof(Operation *op) {
   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
     return llvm::isa<FloatType>(constOp.getType());
@@ -291,6 +356,21 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                            builder.getIndexAttr(value));
 }
 
+arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
+                                                      Location location,
+                                                      int64_t value) {
+  mlir::OperationState state(location, getOperationName());
+  build(builder, state, value);
+  auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create(state));
+  assert(result && "builder didn't return the right type");
+  return result;
+}
+
+arith::ConstantIndexOp
+arith::ConstantIndexOp::create(ImplicitLocOpBuilder &builder, int64_t value) {
+  return create(builder, builder.getLoc(), value);
+}
+
 bool arith::ConstantIndexOp::classof(Operation *op) {
   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
     return constOp.getType().isIndex();
@@ -304,7 +384,7 @@ Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
          "type doesn't have a zero representation");
   TypedAttr zeroAttr = builder.getZeroAttr(type);
   assert(zeroAttr && "unsupported type for zero attribute");
-  return builder.create<arith::ConstantOp>(loc, zeroAttr);
+  return arith::ConstantOp::create(builder, loc, zeroAttr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2334,9 +2414,8 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
     // comparison.
     rewriter.replaceOpWithNewOp<CmpIOp>(
         op, pred, intVal,
-        rewriter.create<ConstantOp>(
-            op.getLoc(), intVal.getType(),
-            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+        ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
+                           rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
     return success();
   }
 };
@@ -2373,10 +2452,10 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
         matchPattern(op.getFalseValue(), m_One())) {
       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
           op, op.getType(),
-          rewriter.create<arith::XOrIOp>(
-              op.getLoc(), op.getCondition(),
-              rewriter.create<arith::ConstantIntOp>(
-                  op.getLoc(), op.getCondition().getType(), 1)));
+          arith::XOrIOp::create(
+              rewriter, op.getLoc(), op.getCondition(),
+              arith::ConstantIntOp::create(rewriter, op.getLoc(),
+                                           op.getCondition().getType(), 1)));
       return success();
     }
 
@@ -2439,12 +2518,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
 
   // Constant-fold constant operands over non-splat constant condition.
   // select %cst_vec, %cst0, %cst1 => %cst2
-  if (auto cond =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
-    if (auto lhs =
-            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
-      if (auto rhs =
-              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+  if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+          adaptor.getCondition())) {
+    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getTrueValue())) {
+      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+              adaptor.getFalseValue())) {
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2692,7 +2771,7 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                     bool useOnlyFiniteValue) {
   auto attr =
       getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
-  return builder.create<arith::ConstantOp>(loc, attr);
+  return arith::ConstantOp::create(builder, loc, attr);
 }
 
 /// Return the value obtained by applying the reduction operation kind
@@ -2701,33 +2780,33 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                   Location loc, Value lhs, Value rhs) {
   switch (op) {
   case AtomicRMWKind::addf:
-    return builder.create<arith::AddFOp>(loc, lhs, rhs);
+    return arith::AddFOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::addi:
-    return builder.create<arith::AddIOp>(loc, lhs, rhs);
+    return arith::AddIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::mulf:
-    return builder.create<arith::MulFOp>(loc, lhs, rhs);
+    return arith::MulFOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::muli:
-    return builder.create<arith::MulIOp>(loc, lhs, rhs);
+    return arith::MulIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::maximumf:
-    return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
+    return arith::MaximumFOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
-    return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
-   case AtomicRMWKind::maxnumf:
-    return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
+    return arith::MinimumFOp::create(builder, loc, lhs, rhs);
+  case AtomicRMWKind::maxnumf:
+    return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::minnumf:
-    return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
+    return arith::MinNumFOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::maxs:
-    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
+    return arith::MaxSIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::mins:
-    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
+    return arith::MinSIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::maxu:
-    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
+    return arith::MaxUIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::minu:
-    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
+    return arith::MinUIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::ori:
-    return builder.create<arith::OrIOp>(loc, lhs, rhs);
+    return arith::OrIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::andi:
-    return builder.create<arith::AndIOp>(loc, lhs, rhs);
+    return arith::AndIOp::create(builder, loc, lhs, rhs);
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");



More information about the Mlir-commits mailing list