[Mlir-commits] [mlir] switch type and value ordering for arith `Constant[XX]Op` (PR #144636)
Skrai Pardus
llvmlistbot at llvm.org
Tue Jun 17 22:01:01 PDT 2025
https://github.com/ashjeong created https://github.com/llvm/llvm-project/pull/144636
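Swap the parameter order of the `arith::ConstantIntOp` (`Type`, `int64_t`) and `arith::ConstantFloatOp` (`FloatType`, `const APFloat &`) builders so that the result type comes before the value, standardizing them with the `build()` methods of the other ops (e.g. `arith::ConstantOp::build(builder, result, type, attr)`), and update callers to the new argument order. Note that the diff also drops the `type.isSignlessInteger()` assertion from `arith::ConstantIntOp::build`, which is a behavior change beyond the reordering.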
From 45206b448b3765686f23d88c0c6c0ea4d76feaf6 Mon Sep 17 00:00:00 2001
From: ashjeong <ashjeong at umich.edu>
Date: Wed, 18 Jun 2025 13:48:34 +0900
Subject: [PATCH] switch type and value ordering for arith `Constant[XX]Op`
---
mlir/include/mlir/Dialect/Arith/IR/Arith.h | 8 ++++----
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 ++++----
.../Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 8 ++++----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 +++-----
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 4 ++--
.../Dialect/Async/Transforms/AsyncParallelFor.cpp | 4 ++--
.../lib/Dialect/GPU/Transforms/AllReduceLowering.cpp | 12 ++++++------
.../Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp | 4 ++--
mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp | 4 ++--
.../Dialect/SCF/Transforms/ParallelLoopTiling.cpp | 2 +-
mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 2 +-
11 files changed, 31 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 77241319851e6..0bee876ac9bfa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -62,8 +62,8 @@ class ConstantIntOp : public arith::ConstantOp {
/// Build a constant int op that produces an integer of the specified type,
/// which must be an integer type.
- static void build(OpBuilder &builder, OperationState &result, int64_t value,
- Type type);
+ static void build(OpBuilder &builder, OperationState &result, Type type,
+ int64_t value);
inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -79,8 +79,8 @@ class ConstantFloatOp : public arith::ConstantOp {
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant float op that produces a float of the specified type.
- static void build(OpBuilder &builder, OperationState &result,
- const APFloat &value, FloatType type);
+ static void build(OpBuilder &builder, OperationState &result, FloatType type,
+ const APFloat &value);
inline APFloat value() {
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6d73f23e2aae1..923f5f67b865a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -244,11 +244,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// Clamp to the negation range.
Value min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
- intermediateType);
+ loc, intermediateType,
+ APInt::getSignedMinValue(inputBitWidth).getSExtValue());
Value max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
- intermediateType);
+ loc, intermediateType,
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
// Truncate to the final value.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 86f5e9baf4a94..c460a8bb2f4b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1073,11 +1073,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
auto min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
- accETy);
+ loc, accETy,
+ APInt::getSignedMinValue(outBitwidth).getSExtValue());
auto max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
- accETy);
+ loc, accETy,
+ APInt::getSignedMaxValue(outBitwidth).getSExtValue());
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
/*isUnsigned=*/false);
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 9e53e195274aa..b9f91a0509103 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -257,9 +257,7 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
}
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
- int64_t value, Type type) {
- assert(type.isSignlessInteger() &&
- "ConstantIntOp can only have signless integer type values");
+ Type type, int64_t value) {
arith::ConstantOp::build(builder, result, type,
builder.getIntegerAttr(type, value));
}
@@ -271,7 +269,7 @@ bool arith::ConstantIntOp::classof(Operation *op) {
}
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
- const APFloat &value, FloatType type) {
+ FloatType type, const APFloat &value) {
arith::ConstantOp::build(builder, result, type,
builder.getFloatAttr(type, value));
}
@@ -2363,7 +2361,7 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
rewriter.create<arith::ConstantIntOp>(
- op.getLoc(), 1, op.getCondition().getType())));
+ op.getLoc(), op.getCondition().getType(), 1)));
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bb4807ab39cd6..3cd8684878a11 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -216,7 +216,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
from = b.create<arith::TruncFOp>(toFpTy, from);
}
Value zero = b.create<mlir::arith::ConstantFloatOp>(
- mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
return b.create<complex::CreateOp>(targetType, from, zero);
}
@@ -229,7 +229,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
from = b.create<arith::SIToFPOp>(toFpTy, from);
}
Value zero = b.create<mlir::arith::ConstantFloatOp>(
- mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
return b.create<complex::CreateOp>(targetType, from, zero);
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 9c776dfa176a4..27fa92cee79c2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -820,13 +820,13 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
const float initialOvershardingFactor = 8.0f;
Value scalingFactor = b.create<arith::ConstantFloatOp>(
- llvm::APFloat(initialOvershardingFactor), b.getF32Type());
+ b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
for (const std::pair<int, float> &p : overshardingBrackets) {
Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
Value inBracket = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
- llvm::APFloat(p.second), b.getF32Type());
+ b.getF32Type(), llvm::APFloat(p.second));
scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
scalingFactor);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index a75598afe8c72..d35f72e5a9e26 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -83,11 +83,11 @@ struct GpuAllReduceRewriter {
// Compute lane id (invocation id withing the subgroup).
Value subgroupMask =
- create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
+ create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
Value isFirstLane =
create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
- create<arith::ConstantIntOp>(0, int32Type));
+ create<arith::ConstantIntOp>(int32Type, 0));
Value numThreadsWithSmallerSubgroupId =
create<arith::SubIOp>(invocationIdx, laneId);
@@ -282,7 +282,7 @@ struct GpuAllReduceRewriter {
/// The first lane returns the result, all others return values are undefined.
Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
AccumulatorFactory &accumFactory) {
- Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+ Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
activeWidth, subgroupSize);
std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
@@ -296,7 +296,7 @@ struct GpuAllReduceRewriter {
// lane is within the active range. The accumulated value is available
// in the first lane.
for (int i = 1; i < kSubgroupSize; i <<= 1) {
- Value offset = create<arith::ConstantIntOp>(i, int32Type);
+ Value offset = create<arith::ConstantIntOp>(int32Type, i);
auto shuffleOp = create<gpu::ShuffleOp>(
shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
// Skip the accumulation if the shuffle op read from a lane outside
@@ -318,7 +318,7 @@ struct GpuAllReduceRewriter {
[&] {
Value value = operand;
for (int i = 1; i < kSubgroupSize; i <<= 1) {
- Value offset = create<arith::ConstantIntOp>(i, int32Type);
+ Value offset = create<arith::ConstantIntOp>(int32Type, i);
auto shuffleOp =
create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
gpu::ShuffleMode::XOR);
@@ -331,7 +331,7 @@ struct GpuAllReduceRewriter {
/// Returns value divided by the subgroup size (i.e. 32).
Value getDivideBySubgroupSize(Value value) {
- Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+ Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
return create<arith::DivSIOp>(int32Type, value, subgroupSize);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 999359c7fa872..1419175304899 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
if (isa<IntegerType>(elementType))
- return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+ return b.create<arith::ConstantIntOp>(loc, elementType, 0);
if (elementType.isIndex())
return b.create<arith::ConstantIndexOp>(loc, 0);
// Assume float.
auto floatType = cast<FloatType>(elementType);
return b.create<arith::ConstantFloatOp>(
- loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+ loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
}
GenericOp
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index c2dbcde1aeba6..793db73575b4f 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -315,9 +315,9 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
auto inputType = input.getType();
auto storageType = quantizedType.getStorageType();
auto storageMinScalar = builder.create<arith::ConstantIntOp>(
- loc, quantizedType.getStorageTypeMin(), storageType);
+ loc, storageType, quantizedType.getStorageTypeMin());
auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
- loc, quantizedType.getStorageTypeMax(), storageType);
+ loc, storageType, quantizedType.getStorageTypeMax());
auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
inputType, inputShape);
auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index ed73d81198f29..66f7bc27f82ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -141,7 +141,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
b.setInsertionPointToStart(innerLoop.getBody());
// Insert in-bound check
Value inbound =
- b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
+ b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1);
for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
innerLoop.getInductionVars(), innerLoop.getStep())) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..29d6d2574a2be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -240,7 +240,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
if (isa<IndexType>(step.getType())) {
one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
} else {
- one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+ one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1);
}
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
More information about the Mlir-commits
mailing list