[Mlir-commits] [mlir] switch type and value ordering for arith `Constant[XX]Op` (PR #144636)

Skrai Pardus llvmlistbot at llvm.org
Tue Jun 17 22:01:01 PDT 2025


https://github.com/ashjeong created https://github.com/llvm/llvm-project/pull/144636

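Swap the parameter order of the `arith::ConstantIntOp` and `arith::ConstantFloatOp` `build()` overloads from `(value, type)` to `(type, value)`, so they follow the same type-before-value convention as other ops' `build()` methods (e.g. `arith::ConstantOp::build(builder, result, type, attr)`), and update the affected callers accordingly.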

>From 45206b448b3765686f23d88c0c6c0ea4d76feaf6 Mon Sep 17 00:00:00 2001
From: ashjeong <ashjeong at umich.edu>
Date: Wed, 18 Jun 2025 13:48:34 +0900
Subject: [PATCH] switch type and value ordering for arith `Constant[XX]Op`

---
 mlir/include/mlir/Dialect/Arith/IR/Arith.h           |  8 ++++----
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp    |  8 ++++----
 .../Conversion/TosaToLinalg/TosaToLinalgNamed.cpp    |  8 ++++----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp               |  6 +++---
 mlir/lib/Dialect/Arith/Utils/Utils.cpp               |  4 ++--
 .../Dialect/Async/Transforms/AsyncParallelFor.cpp    |  4 ++--
 .../lib/Dialect/GPU/Transforms/AllReduceLowering.cpp | 12 ++++++------
 .../Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp |  4 ++--
 mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp  |  4 ++--
 .../Dialect/SCF/Transforms/ParallelLoopTiling.cpp    |  2 +-
 mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp |  2 +-
 11 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 77241319851e6..0bee876ac9bfa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -62,8 +62,8 @@ class ConstantIntOp : public arith::ConstantOp {
 
   /// Build a constant int op that produces an integer of the specified type,
   /// which must be an integer type.
-  static void build(OpBuilder &builder, OperationState &result, int64_t value,
-                    Type type);
+  static void build(OpBuilder &builder, OperationState &result, Type type,
+                    int64_t value);
 
   inline int64_t value() {
     return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -79,8 +79,8 @@ class ConstantFloatOp : public arith::ConstantOp {
   static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
 
   /// Build a constant float op that produces a float of the specified type.
-  static void build(OpBuilder &builder, OperationState &result,
-                    const APFloat &value, FloatType type);
+  static void build(OpBuilder &builder, OperationState &result, FloatType type,
+                    const APFloat &value);
 
   inline APFloat value() {
     return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6d73f23e2aae1..923f5f67b865a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -244,11 +244,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
       // Clamp to the negation range.
       Value min = rewriter.create<arith::ConstantIntOp>(
-          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-          intermediateType);
+          loc, intermediateType,
+          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
       Value max = rewriter.create<arith::ConstantIntOp>(
-          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-          intermediateType);
+          loc, intermediateType,
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
       auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
 
       // Truncate to the final value.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 86f5e9baf4a94..c460a8bb2f4b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1073,11 +1073,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
 
             auto min = rewriter.create<arith::ConstantIntOp>(
-                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
-                accETy);
+                loc, accETy,
+                APInt::getSignedMinValue(outBitwidth).getSExtValue());
             auto max = rewriter.create<arith::ConstantIntOp>(
-                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
-                accETy);
+                loc, accETy,
+                APInt::getSignedMaxValue(outBitwidth).getSExtValue());
             auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                         /*isUnsigned=*/false);
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 9e53e195274aa..b9f91a0509103 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -257,9 +257,9 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
 }
 
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
-                                 int64_t value, Type type) {
+                                 Type type, int64_t value) {
   assert(type.isSignlessInteger() &&
          "ConstantIntOp can only have signless integer type values");
   arith::ConstantOp::build(builder, result, type,
                            builder.getIntegerAttr(type, value));
 }
@@ -271,7 +271,7 @@ bool arith::ConstantIntOp::classof(Operation *op) {
 }
 
 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
-                                   const APFloat &value, FloatType type) {
+                                   FloatType type, const APFloat &value) {
   arith::ConstantOp::build(builder, result, type,
                            builder.getFloatAttr(type, value));
 }
@@ -2363,7 +2363,7 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
           rewriter.create<arith::XOrIOp>(
               op.getLoc(), op.getCondition(),
               rewriter.create<arith::ConstantIntOp>(
-                  op.getLoc(), 1, op.getCondition().getType())));
+                  op.getLoc(), op.getCondition().getType(), 1)));
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bb4807ab39cd6..3cd8684878a11 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -216,7 +216,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
       from = b.create<arith::TruncFOp>(toFpTy, from);
     }
     Value zero = b.create<mlir::arith::ConstantFloatOp>(
-        mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+        toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
     return b.create<complex::CreateOp>(targetType, from, zero);
   }
 
@@ -229,7 +229,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
       from = b.create<arith::SIToFPOp>(toFpTy, from);
     }
     Value zero = b.create<mlir::arith::ConstantFloatOp>(
-        mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+        toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
     return b.create<complex::CreateOp>(targetType, from, zero);
   }
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 9c776dfa176a4..27fa92cee79c2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -820,13 +820,13 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     const float initialOvershardingFactor = 8.0f;
 
     Value scalingFactor = b.create<arith::ConstantFloatOp>(
-        llvm::APFloat(initialOvershardingFactor), b.getF32Type());
+        b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
     for (const std::pair<int, float> &p : overshardingBrackets) {
       Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
       Value inBracket = b.create<arith::CmpIOp>(
           arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
       Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
-          llvm::APFloat(p.second), b.getF32Type());
+          b.getF32Type(), llvm::APFloat(p.second));
       scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
                                                 scalingFactor);
     }
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index a75598afe8c72..d35f72e5a9e26 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -83,11 +83,11 @@ struct GpuAllReduceRewriter {
 
     // Compute lane id (invocation id withing the subgroup).
     Value subgroupMask =
-        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
+        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
     Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
     Value isFirstLane =
         create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
-                              create<arith::ConstantIntOp>(0, int32Type));
+                              create<arith::ConstantIntOp>(int32Type, 0));
 
     Value numThreadsWithSmallerSubgroupId =
         create<arith::SubIOp>(invocationIdx, laneId);
@@ -282,7 +282,7 @@ struct GpuAllReduceRewriter {
   /// The first lane returns the result, all others return values are undefined.
   Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                              AccumulatorFactory &accumFactory) {
-    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
     Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                     activeWidth, subgroupSize);
     std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
@@ -296,7 +296,7 @@ struct GpuAllReduceRewriter {
           // lane is within the active range. The accumulated value is available
           // in the first lane.
           for (int i = 1; i < kSubgroupSize; i <<= 1) {
-            Value offset = create<arith::ConstantIntOp>(i, int32Type);
+            Value offset = create<arith::ConstantIntOp>(int32Type, i);
             auto shuffleOp = create<gpu::ShuffleOp>(
                 shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
             // Skip the accumulation if the shuffle op read from a lane outside
@@ -318,7 +318,7 @@ struct GpuAllReduceRewriter {
         [&] {
           Value value = operand;
           for (int i = 1; i < kSubgroupSize; i <<= 1) {
-            Value offset = create<arith::ConstantIntOp>(i, int32Type);
+            Value offset = create<arith::ConstantIntOp>(int32Type, i);
             auto shuffleOp =
                 create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                        gpu::ShuffleMode::XOR);
@@ -331,7 +331,7 @@ struct GpuAllReduceRewriter {
 
   /// Returns value divided by the subgroup size (i.e. 32).
   Value getDivideBySubgroupSize(Value value) {
-    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
     return create<arith::DivSIOp>(int32Type, value, subgroupSize);
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 999359c7fa872..1419175304899 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) {
   assert(elementType.isIntOrIndexOrFloat() &&
          "expected scalar type while computing zero value");
   if (isa<IntegerType>(elementType))
-    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+    return b.create<arith::ConstantIntOp>(loc, elementType, 0);
   if (elementType.isIndex())
     return b.create<arith::ConstantIndexOp>(loc, 0);
   // Assume float.
   auto floatType = cast<FloatType>(elementType);
   return b.create<arith::ConstantFloatOp>(
-      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+      loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
 }
 
 GenericOp
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index c2dbcde1aeba6..793db73575b4f 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -315,9 +315,9 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
   auto inputType = input.getType();
   auto storageType = quantizedType.getStorageType();
   auto storageMinScalar = builder.create<arith::ConstantIntOp>(
-      loc, quantizedType.getStorageTypeMin(), storageType);
+      loc, storageType, quantizedType.getStorageTypeMin());
   auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
-      loc, quantizedType.getStorageTypeMax(), storageType);
+      loc, storageType, quantizedType.getStorageTypeMax());
   auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                               inputType, inputShape);
   auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index ed73d81198f29..66f7bc27f82ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -141,7 +141,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
     b.setInsertionPointToStart(innerLoop.getBody());
     // Insert in-bound check
     Value inbound =
-        b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
+        b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1);
     for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
          llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
                    innerLoop.getInductionVars(), innerLoop.getStep())) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..29d6d2574a2be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -240,7 +240,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   if (isa<IndexType>(step.getType())) {
     one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   } else {
-    one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+    one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1);
   }
 
   Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);



More information about the Mlir-commits mailing list