[Mlir-commits] [mlir] [mlir] update affine+arith create APIs (1/n) (PR #149656)
Maksim Levental
llvmlistbot at llvm.org
Sat Jul 19 09:52:27 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149656
From 56146a51f7d94f5fe79d9e35fbed9d68c7445d73 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 7 Jul 2025 12:56:56 -0400
Subject: [PATCH 1/2] [mlir] add static Op::create overloads and migrate Affine/Arith builder.create<Op> calls (1/n)
---
.../mlir/Dialect/Affine/IR/AffineOps.h | 21 +++
mlir/include/mlir/Dialect/Arith/IR/Arith.h | 19 +++
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 114 +++++++++-----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 141 ++++++++++++++----
4 files changed, 228 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 2091faa6b0b02..333de6bbd8a05 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -114,6 +114,21 @@ class AffineDmaStartOp
AffineMap tagMap, ValueRange tagIndices, Value numElements,
Value stride = nullptr, Value elementsPerStride = nullptr);
+ static AffineDmaStartOp
+ create(OpBuilder &builder, Location location, Value srcMemRef,
+ AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
+ AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
+ AffineMap tagMap, ValueRange tagIndices, Value numElements,
+ Value stride = nullptr, Value elementsPerStride = nullptr);
+
+ static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef,
+ AffineMap srcMap, ValueRange srcIndices,
+ Value destMemRef, AffineMap dstMap,
+ ValueRange destIndices, Value tagMemRef,
+ AffineMap tagMap, ValueRange tagIndices,
+ Value numElements, Value stride = nullptr,
+ Value elementsPerStride = nullptr);
+
/// Returns the operand index of the source memref.
unsigned getSrcMemRefOperandIndex() { return 0; }
@@ -319,6 +334,12 @@ class AffineDmaWaitOp
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
AffineMap tagMap, ValueRange tagIndices, Value numElements);
+ static AffineDmaWaitOp create(OpBuilder &builder, Location location,
+ Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices, Value numElements);
+ static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef,
+ AffineMap tagMap, ValueRange tagIndices,
+ Value numElements);
static StringRef getOperationName() { return "affine.dma_wait"; }
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 7c50c2036ffdc..0fc3db8e993d8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp {
/// Build a constant int op that produces an integer of the specified width.
static void build(OpBuilder &builder, OperationState &result, int64_t value,
unsigned width);
+ static ConstantIntOp create(OpBuilder &builder, Location location,
+ int64_t value, unsigned width);
+ static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value,
+ unsigned width);
/// Build a constant int op that produces an integer of the specified type,
/// which must be an integer type.
static void build(OpBuilder &builder, OperationState &result, Type type,
int64_t value);
+ static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
+ int64_t value);
+ static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
+ int64_t value);
/// Build a constant int op that produces an integer from an APInt
static void build(OpBuilder &builder, OperationState &result, Type type,
const APInt &value);
+ static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
+ const APInt &value);
+ static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
+ const APInt &value);
inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp {
/// Build a constant float op that produces a float of the specified type.
static void build(OpBuilder &builder, OperationState &result, FloatType type,
const APFloat &value);
+ static ConstantFloatOp create(OpBuilder &builder, Location location,
+ FloatType type, const APFloat &value);
+ static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type,
+ const APFloat &value);
inline APFloat value() {
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
@@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp {
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant int op that produces an index.
static void build(OpBuilder &builder, OperationState &result, int64_t value);
+ static ConstantIndexOp create(OpBuilder &builder, Location location,
+ int64_t value);
+ static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value);
inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e8e8f624d806e..bd1be7c087d05 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -240,7 +240,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -1282,7 +1282,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
map = foldAttributesIntoMap(b, map, operands, valueOperands);
composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
assert(map);
- return b.create<AffineApplyOp>(loc, map, valueOperands);
+ return AffineApplyOp::create(b, loc, map, valueOperands);
}
AffineApplyOp
@@ -1389,7 +1389,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
SmallVector<Value> valueOperands;
map = foldAttributesIntoMap(b, map, operands, valueOperands);
composeMultiResultAffineMap(map, valueOperands);
- return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
+ return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
}
AffineMinOp
@@ -1747,6 +1747,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
}
}
+AffineDmaStartOp AffineDmaStartOp::create(
+ OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
+ ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
+ ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices, Value numElements, Value stride,
+ Value elementsPerStride) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
+ destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
+ elementsPerStride);
+ auto result = llvm::dyn_cast<AffineDmaStartOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+AffineDmaStartOp AffineDmaStartOp::create(
+ ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
+ ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
+ ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices, Value numElements, Value stride,
+ Value elementsPerStride) {
+ return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
+ destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
+ numElements, stride, elementsPerStride);
+}
+
void AffineDmaStartOp::print(OpAsmPrinter &p) {
p << " " << getSrcMemRef() << '[';
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
@@ -1917,6 +1943,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(numElements);
}
+AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
+ Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices,
+ Value numElements) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
+ auto result = llvm::dyn_cast<AffineDmaWaitOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
+ Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices,
+ Value numElements) {
+ return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
+ numElements);
+}
+
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
p << " " << getTagMemRef() << '[';
SmallVector<Value, 2> operands(getTagIndices());
@@ -2688,8 +2733,8 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
rewriter.setInsertionPoint(getOperation());
auto inits = llvm::to_vector(getInits());
inits.append(newInitOperands.begin(), newInitOperands.end());
- AffineForOp newLoop = rewriter.create<AffineForOp>(
- getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
+ AffineForOp newLoop = AffineForOp::create(
+ rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
// Generate the new yield values and append them to the scf.yield operation.
@@ -2831,7 +2876,7 @@ static void buildAffineLoopNestImpl(
OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
}
- nestedBuilder.create<AffineYieldOp>(nestedLoc);
+ AffineYieldOp::create(nestedBuilder, nestedLoc);
};
// Delegate actual loop creation to the callback in order to dispatch
@@ -2846,8 +2891,8 @@ static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
int64_t ub, int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
- return builder.create<AffineForOp>(loc, lb, ub, step,
- /*iterArgs=*/ValueRange(), bodyBuilderFn);
+ return AffineForOp::create(builder, loc, lb, ub, step,
+ /*iterArgs=*/ValueRange(), bodyBuilderFn);
}
/// Creates an affine loop from the bounds that may or may not be constants.
@@ -2860,9 +2905,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
if (lbConst && ubConst)
return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
ubConst.value(), step, bodyBuilderFn);
- return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
- builder.getDimIdentityMap(), step,
- /*iterArgs=*/ValueRange(), bodyBuilderFn);
+ return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
+ builder.getDimIdentityMap(), step,
+ /*iterArgs=*/ValueRange(), bodyBuilderFn);
}
void mlir::affine::buildAffineLoopNest(
@@ -4883,7 +4928,7 @@ struct DropUnitExtentBasis
Location loc = delinearizeOp->getLoc();
auto getZero = [&]() -> Value {
if (!zero)
- zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
return zero.value();
};
@@ -4906,8 +4951,8 @@ struct DropUnitExtentBasis
if (!newBasis.empty()) {
// Will drop the leading nullptr from `basis` if there was no outer bound.
- auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis);
+ auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
+ rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4971,12 +5016,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
- Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
- linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
+ Value newLinearize = affine::AffineLinearizeIndexOp::create(
+ rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
linearizeOp.getDisjoint());
- auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
- delinearizeOp.getLoc(), newLinearize,
+ auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), newLinearize,
ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
delinearizeOp.hasOuterBound());
SmallVector<Value> mergedResults(newDelinearize.getResults());
@@ -5048,19 +5093,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
delinearizeOp,
"need at least two elements to form the basis product");
- Value linearizeWithoutBack =
- rewriter.create<affine::AffineLinearizeIndexOp>(
- linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
- linearizeOp.getDynamicBasis(),
- linearizeOp.getStaticBasis().drop_back(),
- linearizeOp.getDisjoint());
- auto delinearizeWithoutSplitPart =
- rewriter.create<affine::AffineDelinearizeIndexOp>(
- delinearizeOp.getLoc(), linearizeWithoutBack,
- delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
- delinearizeOp.hasOuterBound());
- auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
- delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+ Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
+ rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+ linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
+ linearizeOp.getDisjoint());
+ auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
+ delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+ delinearizeOp.hasOuterBound());
+ auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
SmallVector<Value> results = llvm::to_vector(
llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
@@ -5272,7 +5314,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder,
}
if (auto constant = dyn_cast<AffineConstantExpr>(result))
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
- return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+ return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
}
/// If conseceutive outputs of a delinearize_index are linearized with the same
@@ -5437,16 +5479,16 @@ struct CancelLinearizeOfDelinearizePortion final
newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
newDelinBasis.begin() + m.delinStart + m.length);
newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
- auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
- m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ auto newDelinearize = AffineDelinearizeIndexOp::create(
+ rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
newDelinBasis);
// Since there may be other uses of the indices we just merged together,
// create a residual affine.delinearize_index that delinearizes the
// merged output into its component parts.
Value combinedElem = newDelinearize.getResult(m.delinStart);
- auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
- m.delinearize.getLoc(), combinedElem, basisToMerge);
+ auto residualDelinearize = AffineDelinearizeIndexOp::create(
+ rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
// Swap all the uses of the unaffected delinearize outputs to the new
// delinearization so that the old code can be removed if this
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 4e40d4ebda004..ac00917d33a85 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -242,7 +242,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+ return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
return nullptr;
}
@@ -255,18 +255,66 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
builder.getIntegerAttr(type, value));
}
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+ Location location,
+ int64_t value,
+ unsigned width) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, value, width);
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+ int64_t value,
+ unsigned width) {
+ return create(builder, builder.getLoc(), value, width);
+}
+
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
Type type, int64_t value) {
arith::ConstantOp::build(builder, result, type,
builder.getIntegerAttr(type, value));
}
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+ Location location, Type type,
+ int64_t value) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, type, value);
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+ Type type, int64_t value) {
+ return create(builder, builder.getLoc(), type, value);
+}
+
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
Type type, const APInt &value) {
arith::ConstantOp::build(builder, result, type,
builder.getIntegerAttr(type, value));
}
+arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
+ Location location, Type type,
+ const APInt &value) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, type, value);
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
+ Type type,
+ const APInt &value) {
+ return create(builder, builder.getLoc(), type, value);
+}
+
bool arith::ConstantIntOp::classof(Operation *op) {
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
return constOp.getType().isSignlessInteger();
@@ -279,6 +327,23 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
builder.getFloatAttr(type, value));
}
+arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
+ Location location,
+ FloatType type,
+ const APFloat &value) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, type, value);
+ auto result = llvm::dyn_cast<ConstantFloatOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+arith::ConstantFloatOp
+arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder, FloatType type,
+ const APFloat &value) {
+ return create(builder, builder.getLoc(), type, value);
+}
+
bool arith::ConstantFloatOp::classof(Operation *op) {
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
return llvm::isa<FloatType>(constOp.getType());
@@ -291,6 +356,21 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
builder.getIndexAttr(value));
}
+arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
+ Location location,
+ int64_t value) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, value);
+ auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+}
+
+arith::ConstantIndexOp
+arith::ConstantIndexOp::create(ImplicitLocOpBuilder &builder, int64_t value) {
+ return create(builder, builder.getLoc(), value);
+}
+
bool arith::ConstantIndexOp::classof(Operation *op) {
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
return constOp.getType().isIndex();
@@ -304,7 +384,7 @@ Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
"type doesn't have a zero representation");
TypedAttr zeroAttr = builder.getZeroAttr(type);
assert(zeroAttr && "unsupported type for zero attribute");
- return builder.create<arith::ConstantOp>(loc, zeroAttr);
+ return arith::ConstantOp::create(builder, loc, zeroAttr);
}
//===----------------------------------------------------------------------===//
@@ -2334,9 +2414,8 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
// comparison.
rewriter.replaceOpWithNewOp<CmpIOp>(
op, pred, intVal,
- rewriter.create<ConstantOp>(
- op.getLoc(), intVal.getType(),
- rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+ ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
+ rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
return success();
}
};
@@ -2373,10 +2452,10 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
matchPattern(op.getFalseValue(), m_One())) {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
op, op.getType(),
- rewriter.create<arith::XOrIOp>(
- op.getLoc(), op.getCondition(),
- rewriter.create<arith::ConstantIntOp>(
- op.getLoc(), op.getCondition().getType(), 1)));
+ arith::XOrIOp::create(
+ rewriter, op.getLoc(), op.getCondition(),
+ arith::ConstantIntOp::create(rewriter, op.getLoc(),
+ op.getCondition().getType(), 1)));
return success();
}
@@ -2439,12 +2518,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
- if (auto cond =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
- if (auto lhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
- if (auto rhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+ if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getCondition())) {
+ if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getTrueValue())) {
+ if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2692,7 +2771,7 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
bool useOnlyFiniteValue) {
auto attr =
getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
- return builder.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(builder, loc, attr);
}
/// Return the value obtained by applying the reduction operation kind
@@ -2701,33 +2780,33 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
Location loc, Value lhs, Value rhs) {
switch (op) {
case AtomicRMWKind::addf:
- return builder.create<arith::AddFOp>(loc, lhs, rhs);
+ return arith::AddFOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::addi:
- return builder.create<arith::AddIOp>(loc, lhs, rhs);
+ return arith::AddIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::mulf:
- return builder.create<arith::MulFOp>(loc, lhs, rhs);
+ return arith::MulFOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::muli:
- return builder.create<arith::MulIOp>(loc, lhs, rhs);
+ return arith::MulIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::maximumf:
- return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
+ return arith::MaximumFOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::minimumf:
- return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxnumf:
- return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
+ return arith::MinimumFOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::maxnumf:
+ return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::minnumf:
- return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
+ return arith::MinNumFOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::maxs:
- return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
+ return arith::MaxSIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::mins:
- return builder.create<arith::MinSIOp>(loc, lhs, rhs);
+ return arith::MinSIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::maxu:
- return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
+ return arith::MaxUIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::minu:
- return builder.create<arith::MinUIOp>(loc, lhs, rhs);
+ return arith::MinUIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::ori:
- return builder.create<arith::OrIOp>(loc, lhs, rhs);
+ return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
- return builder.create<arith::AndIOp>(loc, lhs, rhs);
+ return arith::AndIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
From 45fb25d9d32dd84a495489e68dad6590508a41b6 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sat, 19 Jul 2025 12:52:16 -0400
Subject: [PATCH 2/2] [mlir] drop redundant llvm:: qualifiers from cast calls
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 30 ++++++-------
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 55 ++++++++++++------------
2 files changed, 42 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bd1be7c087d05..ee5db073ffc4e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -49,7 +49,7 @@ using llvm::mod;
/// top level of a `AffineScope` region is always a valid symbol for all
/// uses in that region.
bool mlir::affine::isTopLevelValue(Value value, Region *region) {
- if (auto arg = llvm::dyn_cast<BlockArgument>(value))
+ if (auto arg = dyn_cast<BlockArgument>(value))
return arg.getParentRegion() == region;
return value.getDefiningOp()->getParentRegion() == region;
}
@@ -249,7 +249,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
/// conservatively assume it is not top-level. A value of index type defined at
/// the top level is always a valid symbol.
bool mlir::affine::isTopLevelValue(Value value) {
- if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
// The block owning the argument may be unlinked, e.g. when the surrounding
// region has not yet been attached to an Op, at which point the parent Op
// is null.
@@ -1757,7 +1757,7 @@ AffineDmaStartOp AffineDmaStartOp::create(
build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
elementsPerStride);
- auto result = llvm::dyn_cast<AffineDmaStartOp>(builder.create(state));
+ auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -1949,7 +1949,7 @@ AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
Value numElements) {
mlir::OperationState state(location, getOperationName());
build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
- auto result = llvm::dyn_cast<AffineDmaWaitOp>(builder.create(state));
+ auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -2198,7 +2198,7 @@ static ParseResult parseBound(bool isLower, OperationState &result,
return failure();
// Parse full form - affine map followed by dim and symbol list.
- if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
+ if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
unsigned currentNumOperands = result.operands.size();
unsigned numDims;
if (parseDimAndSymbolList(p, result.operands, numDims))
@@ -2231,7 +2231,7 @@ static ParseResult parseBound(bool isLower, OperationState &result,
}
// Parse custom assembly form.
- if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
+ if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
result.attributes.pop_back();
result.addAttribute(
boundAttrStrName,
@@ -2801,7 +2801,7 @@ bool mlir::affine::isAffineInductionVar(Value val) {
}
AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
- auto ivArg = llvm::dyn_cast<BlockArgument>(val);
+ auto ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
return AffineForOp();
if (auto forOp =
@@ -2812,7 +2812,7 @@ AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
}
AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
- auto ivArg = llvm::dyn_cast<BlockArgument>(val);
+ auto ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
return nullptr;
Operation *containingOp = ivArg.getOwner()->getParentOp();
@@ -3339,11 +3339,11 @@ OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
// Check if the global memref is a constant.
auto cstAttr =
- llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
+ dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
if (!cstAttr)
return {};
// If it's a splat constant, we can fold irrespective of indices.
- if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
return splatAttr.getSplatValue<Attribute>();
// Otherwise, we can fold only if we know the indices.
if (!getAffineMap().isConstant())
@@ -4110,19 +4110,19 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
case arith::AtomicRMWKind::minimumf:
return isa<FloatType>(resultType);
case arith::AtomicRMWKind::maxs: {
- auto intType = llvm::dyn_cast<IntegerType>(resultType);
+ auto intType = dyn_cast<IntegerType>(resultType);
return intType && intType.isSigned();
}
case arith::AtomicRMWKind::mins: {
- auto intType = llvm::dyn_cast<IntegerType>(resultType);
+ auto intType = dyn_cast<IntegerType>(resultType);
return intType && intType.isSigned();
}
case arith::AtomicRMWKind::maxu: {
- auto intType = llvm::dyn_cast<IntegerType>(resultType);
+ auto intType = dyn_cast<IntegerType>(resultType);
return intType && intType.isUnsigned();
}
case arith::AtomicRMWKind::minu: {
- auto intType = llvm::dyn_cast<IntegerType>(resultType);
+ auto intType = dyn_cast<IntegerType>(resultType);
return intType && intType.isUnsigned();
}
case arith::AtomicRMWKind::ori:
@@ -4179,7 +4179,7 @@ LogicalResult AffineParallelOp::verify() {
// ops
for (auto it : llvm::enumerate((getReductions()))) {
Attribute attr = it.value();
- auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return emitOpError("invalid reduction attribute");
auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ac00917d33a85..910334b17748b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -148,7 +148,7 @@ static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
static Attribute getBoolAttribute(Type type, bool value) {
auto boolAttr = BoolAttr::get(type.getContext(), value);
- ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
+ ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
@@ -169,7 +169,7 @@ namespace {
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+ if (auto shapedType = dyn_cast<ShapedType>(type))
return shapedType.cloneWith(std::nullopt, i1Type);
if (llvm::isa<UnrankedTensorType>(type))
return UnrankedTensorType::get(i1Type);
@@ -183,8 +183,8 @@ static Type getI1SameShape(Type type) {
void arith::ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto type = getType();
- if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
- auto intType = llvm::dyn_cast<IntegerType>(type);
+ if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
+ auto intType = dyn_cast<IntegerType>(type);
// Sugar i1 constants with 'true' and 'false'.
if (intType && intType.getWidth() == 1)
@@ -228,7 +228,7 @@ LogicalResult arith::ConstantOp::verify() {
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
// The value's type must be the same as the provided type.
- auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
+ auto typedAttr = dyn_cast<TypedAttr>(value);
if (!typedAttr || typedAttr.getType() != type)
return false;
// Integer values must be signless.
@@ -261,7 +261,7 @@ arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
unsigned width) {
mlir::OperationState state(location, getOperationName());
build(builder, state, value, width);
- auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ auto result = dyn_cast<ConstantIntOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -283,7 +283,7 @@ arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
int64_t value) {
mlir::OperationState state(location, getOperationName());
build(builder, state, type, value);
- auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ auto result = dyn_cast<ConstantIntOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -304,7 +304,7 @@ arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
const APInt &value) {
mlir::OperationState state(location, getOperationName());
build(builder, state, type, value);
- auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
+ auto result = dyn_cast<ConstantIntOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -333,7 +333,7 @@ arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
const APFloat &value) {
mlir::OperationState state(location, getOperationName());
build(builder, state, type, value);
- auto result = llvm::dyn_cast<ConstantFloatOp>(builder.create(state));
+ auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -361,7 +361,7 @@ arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
int64_t value) {
mlir::OperationState state(location, getOperationName());
build(builder, state, value);
- auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create(state));
+ auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}
@@ -423,7 +423,7 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
- if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
+ if (auto vt = dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -569,7 +569,7 @@ void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
- if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
+ if (auto vt = dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -615,7 +615,7 @@ void arith::MulSIExtendedOp::getCanonicalizationPatterns(
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
- if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
+ if (auto vt = dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -1895,7 +1895,7 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
return {};
/// Bitcast dense elements.
- if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
+ if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
/// Other shaped types unhandled.
if (llvm::isa<ShapedType>(resType))
@@ -1912,7 +1912,7 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
"trying to fold on broken IR: operands have incompatible types");
- if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
+ if (auto resFloatType = dyn_cast<FloatType>(resType))
return FloatAttr::get(resType,
APFloat(resFloatType.getFloatSemantics(), bits));
return IntegerAttr::get(resType, bits);
@@ -1976,10 +1976,10 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
}
static std::optional<int64_t> getIntegerWidth(Type t) {
- if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
+ if (auto intType = dyn_cast<IntegerType>(t)) {
return intType.getWidth();
}
- if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
+ if (auto vectorIntType = dyn_cast<VectorType>(t)) {
return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
}
return std::nullopt;
@@ -2049,7 +2049,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
- if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
+ if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getI1SameShape(lhs.getType()),
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
@@ -2119,8 +2119,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
- auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
- auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
+ auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
+ auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
// If one operand is NaN, making them both NaN does not change the result.
if (lhs && lhs.getValue().isNaN())
@@ -2518,12 +2518,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
- if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
- adaptor.getCondition())) {
- if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
- adaptor.getTrueValue())) {
- if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
- adaptor.getFalseValue())) {
+ if (auto cond =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
+ if (auto lhs =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
+ if (auto rhs =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2572,8 +2572,7 @@ void arith::SelectOp::print(OpAsmPrinter &p) {
p << " " << getOperands();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : ";
- if (ShapedType condType =
- llvm::dyn_cast<ShapedType>(getCondition().getType()))
+ if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
p << condType << ", ";
p << getType();
}
More information about the Mlir-commits
mailing list