[Mlir-commits] [mlir] b0312be - [mlir][NFC] update `mlir/Dialect` create APIs (19/n) (#149926)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 22 07:13:47 PDT 2025
Author: Maksim Levental
Date: 2025-07-22T10:13:44-04:00
New Revision: b0312be6aa664e4cb9abec6d080e971493093d05
URL: https://github.com/llvm/llvm-project/commit/b0312be6aa664e4cb9abec6d080e971493093d05
DIFF: https://github.com/llvm/llvm-project/commit/b0312be6aa664e4cb9abec6d080e971493093d05.diff
LOG: [mlir][NFC] update `mlir/Dialect` create APIs (19/n) (#149926)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
Added:
Modified:
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/MPI/IR/MPIOps.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index bab9e2852a4607..a3e1542e6a9476 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
- return b.create<BoolConstantOp>(loc, type, boolValue);
+ return BoolConstantOp::create(b, loc, type, boolValue);
}
// Materialize integer attributes as `index`.
@@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
- return b.create<ConstantOp>(loc, indexValue);
+ return ConstantOp::create(b, loc, indexValue);
}
return nullptr;
@@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
index::CmpOp newCmp;
if (rhsIsZero)
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getLhs(), subOp.getRhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getLhs(), subOp.getRhs());
else
- newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
- subOp.getRhs(), subOp.getLhs());
+ newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
+ subOp.getRhs(), subOp.getLhs());
rewriter.replaceOp(op, newCmp);
return success();
}
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index ff6af63eee531b..364e4d385fd628 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -135,8 +135,9 @@ struct GlobalStoreOpInterface
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
auto loc = globalStoreOp.getLoc();
- auto targetMemref = rewriter.create<memref::GetGlobalOp>(
- loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+ auto targetMemref = memref::GetGlobalOp::create(
+ rewriter, loc, memrefType,
+ globalStoreOp.getGlobalAttr().getLeafReference());
auto sourceMemref =
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 7940ff60a48e7d..f52c3f99189d2a 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
if (!isa<IntegerAttr>(dltiAttr.value()))
return op->emitError()
<< "Expected an integer attribute for MPI:comm_world_rank";
- Value res = b.create<arith::ConstantIndexOp>(
- op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+ Value res = arith::ConstantIndexOp::create(
+ b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
if (Value retVal = op.getRetval())
b.replaceOp(op, {retVal, res});
else
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 26441a9d78658e..a21631cbf8510d 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 13e2a4b5541b21..31785eb20a6427 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
return value;
};
@@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square =
- rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
+ arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
return success();
}
// Replace `pow(x, -1.0)` with `1.0 / x`.
if (isExponentValue(-1.0)) {
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+ Value one = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
return success();
}
@@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
if (isExponentValue(0.75)) {
- Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
- Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+ Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
+ Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
ValueRange{powHalf, powQuarter});
return success();
@@ -168,18 +169,18 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+ return vector::BroadcastOp::create(rewriter, loc, vec, value);
return value;
};
Value one;
Type opType = getElementTypeOrSelf(op.getType());
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(opType, 1.0));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getFloatAttr(opType, 1.0));
else
- one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(opType, 1));
+ one = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(opType, 1));
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -208,12 +209,12 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
for (unsigned i = 1; i < exponentValue; ++i)
- result = rewriter.create<MulOpTy>(loc, result, base);
+ result = MulOpTy::create(rewriter, loc, result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
if (exponentIsNegative)
- result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+ result = DivOpTy::create(rewriter, loc, bcast(one), result);
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bccd486def4bf0..5edb6e28fb0185 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createFloatConst(Location loc, Type type, double value,
@@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return b.create<arith::ConstantOp>(loc,
- DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(b, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return b.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(b, loc, attr);
}
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
@@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
Type i64Ty = b.getI64Type();
if (auto shapedTy = dyn_cast<ShapedType>(opType))
i64Ty = shapedTy.clone(i64Ty);
- Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
- Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+ Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
+ Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
// The truncation does not preserve the sign when the truncated
// value is -0. So here the sign is copied again.
- return b.create<math::CopySignOp>(fpFixedConvert, operand);
+ return math::CopySignOp::create(b, fpFixedConvert, operand);
}
// sinhf(float x) -> (exp(x) - exp(-x)) / 2
@@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value sub = b.create<arith::SubFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value sub = arith::SubFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(sub, half);
+ Value res = arith::MulFOp::create(b, sub, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value neg = b.create<arith::NegFOp>(operand);
- Value nexp = b.create<math::ExpOp>(neg);
- Value add = b.create<arith::AddFOp>(exp, nexp);
+ Value exp = math::ExpOp::create(b, operand);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value nexp = math::ExpOp::create(b, neg);
+ Value add = arith::AddFOp::create(b, exp, nexp);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(add, half);
+ Value res = arith::MulFOp::create(b, add, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value isNegative = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+ Value isNegative = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
Value isNegativeFloat =
- rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+ arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
Value isNegativeTimesNegTwo =
- rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
- Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+ arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
+ Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
// Normalize input to positive value: y = sign(x) * x
- Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+ Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
// Decompose on normalized input
- Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+ Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
+ Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
+ Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
+ Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
+ Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
// Multiply result by sign(x) to retain signs from negative inputs
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
@@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type type = operand.getType();
- Value sin = b.create<math::SinOp>(type, operand);
- Value cos = b.create<math::CosOp>(type, operand);
- Value div = b.create<arith::DivFOp>(type, sin, cos);
+ Value sin = math::SinOp::create(b, type, operand);
+ Value cos = math::CosOp::create(b, type, operand);
+ Value div = arith::DivFOp::create(b, type, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, one);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, one);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op,
Type opType = operand.getType();
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
- Value fma = b.create<math::FmaOp>(operand, operand, negOne);
- Value sqrt = b.create<math::SqrtOp>(fma);
- Value add = b.create<arith::AddFOp>(operand, sqrt);
- Value res = b.create<math::LogOp>(add);
+ Value fma = math::FmaOp::create(b, operand, operand, negOne);
+ Value sqrt = math::SqrtOp::create(b, fma);
+ Value add = arith::AddFOp::create(b, operand, sqrt);
+ Value res = math::LogOp::create(b, add);
rewriter.replaceOp(op, res);
return success();
}
@@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op,
Type opType = operand.getType();
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value add = b.create<arith::AddFOp>(operand, one);
- Value neg = b.create<arith::NegFOp>(operand);
- Value sub = b.create<arith::AddFOp>(neg, one);
- Value div = b.create<arith::DivFOp>(add, sub);
- Value log = b.create<math::LogOp>(div);
+ Value add = arith::AddFOp::create(b, operand, one);
+ Value neg = arith::NegFOp::create(b, operand);
+ Value sub = arith::AddFOp::create(b, neg, one);
+ Value div = arith::DivFOp::create(b, add, sub);
+ Value log = math::LogOp::create(b, div);
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
- Value res = b.create<arith::MulFOp>(log, half);
+ Value res = arith::MulFOp::create(b, log, half);
rewriter.replaceOp(op, res);
return success();
}
@@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
- Value add = b.create<arith::AddFOp>(type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+ Value add = arith::AddFOp::create(b, type, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
- fpFixedConvert);
- Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+ Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
+ fpFixedConvert);
+ Value incrValue =
+ arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+ Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
rewriter.replaceOp(op, ret);
return success();
}
@@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
auto convertFPowItoPowf = [&]() -> LogicalResult {
Value castPowerToFp =
- rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
- Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
- castPowerToFp);
+ arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
+ Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
+ castPowerToFp);
rewriter.replaceOp(op, res);
return success();
};
@@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = b.create<arith::MulFOp>(baseType, base, res);
+ res = arith::MulFOp::create(b, baseType, base, res);
absPower >>= 1;
- base = b.create<arith::MulFOp>(baseType, base, base);
+ base = arith::MulFOp::create(b, baseType, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
- res = b.create<arith::DivFOp>(baseType, one, res);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
+ res = arith::DivFOp::create(b, baseType, one, res);
res =
- b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
- res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
- res);
+ arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
+ res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
+ res);
}
rewriter.replaceOp(op, res);
@@ -330,7 +331,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
auto mulf = [&](Value x, Value y) -> Value {
- return b.create<arith::MulFOp>(x, y);
+ return arith::MulFOp::create(b, x, y);
};
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
if (valueB.isZero()) {
@@ -347,19 +348,19 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
if (valueB.isExactlyValue(-1.0)) {
// a^(-1) -> 1 / a
Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, operandA);
+ Value div = arith::DivFOp::create(b, one, operandA);
rewriter.replaceOp(op, div);
return success();
}
if (valueB.isExactlyValue(0.5)) {
// a^(1/2) -> sqrt(a)
- Value sqrt = b.create<math::SqrtOp>(operandA);
+ Value sqrt = math::SqrtOp::create(b, operandA);
rewriter.replaceOp(op, sqrt);
return success();
}
if (valueB.isExactlyValue(-0.5)) {
// a^(-1/2) -> 1 / sqrt(a)
- Value rsqrt = b.create<math::RsqrtOp>(operandA);
+ Value rsqrt = math::RsqrtOp::create(b, operandA);
rewriter.replaceOp(op, rsqrt);
return success();
}
@@ -372,7 +373,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
// a^(-2) -> 1 / (a * a)
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
+ Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
rewriter.replaceOp(op, div);
return success();
}
@@ -382,9 +383,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
}
}
- Value logA = b.create<math::LogOp>(operandA);
- Value mult = b.create<arith::MulFOp>(operandB, logA);
- Value expResult = b.create<math::ExpOp>(mult);
+ Value logA = math::LogOp::create(b, operandA);
+ Value mult = arith::MulFOp::create(b, operandB, logA);
+ Value expResult = math::ExpOp::create(b, mult);
rewriter.replaceOp(op, expResult);
return success();
}
@@ -399,8 +400,8 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
Value operand = op.getOperand();
Type opType = operand.getType();
Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
- Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
- Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
+ Value mult = arith::MulFOp::create(b, opType, operand, ln2);
+ Value exp = math::ExpOp::create(b, op->getLoc(), mult);
rewriter.replaceOp(op, exp);
return success();
}
@@ -426,8 +427,8 @@ static LogicalResult convertRoundOp(math::RoundOp op,
Value c127 = createIntConst(loc, i32Ty, 127, b);
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
- Value incrValue = b.create<math::CopySignOp>(half, operand);
- Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
+ Value incrValue = math::CopySignOp::create(b, half, operand);
+ Value add = arith::AddFOp::create(b, opType, operand, incrValue);
Value fpFixedConvert = createTruncatedFPValue(add, b);
// There are three cases where adding 0.5 to the value and truncating by
@@ -450,15 +451,15 @@ static LogicalResult convertRoundOp(math::RoundOp op,
// i64 leading to wrong outputs.
//
// All three cases satisfy the property `biasedExp >= 23`.
- Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
- Value operandExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
- Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
- Value isSpecialValOrLargeVal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
-
- Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
- fpFixedConvert);
+ Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
+ Value operandExp = arith::AndIOp::create(
+ b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
+ Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
+ Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
+
+ Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
+ fpFixedConvert);
rewriter.replaceOp(op, result);
return success();
}
@@ -488,21 +489,21 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
auto bits = createIntConst(loc, operandTy, half, rewriter);
auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
- Value pred =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
- Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
- Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
+ Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
+ x, mask);
+ Value add = arith::AddIOp::create(rewriter, loc, count, bits);
+ Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
- x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
- count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
+ x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
+ count = arith::SelectOp::create(rewriter, loc, pred, add, count);
}
Value zero = createIntConst(loc, operandTy, 0, rewriter);
- Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- operand, zero);
+ Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ operand, zero);
Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
- Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
+ Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
rewriter.replaceOp(op, sel);
return success();
}
@@ -549,29 +550,29 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
- Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
- Value round = b.create<math::RoundOp>(operand);
- Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
+ Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+ Value round = math::RoundOp::create(b, operand);
+ Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
// Get biased exponents for operand and round(operand)
- Value operandExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
- Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
- Value roundExp = b.create<arith::AndIOp>(
- b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
- Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
+ Value operandExp = arith::AndIOp::create(
+ b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
+ Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
+ Value roundExp = arith::AndIOp::create(
+ b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
+ Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
auto safeShiftRight = [&](Value x, Value shift) -> Value {
// Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
- Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
- clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
- return b.create<arith::ShRUIOp>(x, clampedShift);
+ Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
+ clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
+ return arith::ShRUIOp::create(b, x, clampedShift);
};
auto maskMantissa = [&](Value mantissa,
Value mantissaMaskRightShift) -> Value {
Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
- return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
+ return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
};
// A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
@@ -589,13 +590,13 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
// `biasedExp > 23`, so they get treated as large numbers with no room for
// decimals, which are always even.
Value roundBiasedExpEq0 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
- Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
+ Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
- Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
+ Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
roundIsNotEvenOrSpecialVal =
- b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
+ arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
// A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
// integers if the bit at index `biasedExp` starting from the left in the
@@ -604,37 +605,37 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
// values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
// so these are handled separately. In particular, if `biasedExp == -1`, the
// value is halfway if the entire mantissa is zero.
- Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
- Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
- operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
+ Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
+ Value expectedOperandMaskedMantissa = arith::SelectOp::create(
+ b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
Value operandIsHalfway =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
- expectedOperandMaskedMantissa);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
+ expectedOperandMaskedMantissa);
// Ensure `biasedExp` is in the valid range for half values.
- Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
- Value operandBiasedExpLt23 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
+ Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
+ Value operandBiasedExpLt23 = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
operandIsHalfway =
- b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
+ arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
operandIsHalfway =
- b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
+ arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
// case where `round` rounded in the opposite direction of `roundeven`.
- Value sign = b.create<math::CopySignOp>(c1Float, operand);
- Value roundShifted = b.create<arith::SubFOp>(round, sign);
+ Value sign = math::CopySignOp::create(b, c1Float, operand);
+ Value roundShifted = arith::SubFOp::create(b, round, sign);
// If the rounded value is even or a special value, we default to the behavior
// of `math.round`.
Value needsShift =
- b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
- Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
+ arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
+ Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
// The `x - sign` adjustment does not preserve the sign when we are adjusting
// the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
// rounded to -0.0.
- result = b.create<math::CopySignOp>(result, operand);
+ result = math::CopySignOp::create(b, result, operand);
rewriter.replaceOp(op, result);
return success();
}
@@ -656,7 +657,7 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
Location loc = op->getLoc();
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
- auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
+ auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
return success();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
index a570ed5118ef0b..9d6ad613fc945f 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -73,7 +73,7 @@ void mlir::math::populateExtendToSupportedTypesTypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = arith::ExtFOp::create(b, loc, target, input);
extFOp.setFastmath(arith::FastMathFlags::contract);
return extFOp;
});
@@ -104,7 +104,7 @@ LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType, result);
truncFOp.setFastmath(arith::FastMathFlags::contract);
result = truncFOp.getResult();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index dd2dfe372b683a..76720cfd4a98cd 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -72,7 +72,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
std::optional<VectorShape> shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
- return shape ? builder.create<BroadcastOp>(type, value) : value;
+ return shape ? BroadcastOp::create(builder, type, value) : value;
}
//----------------------------------------------------------------------------//
@@ -130,7 +130,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
auto eltType = cast<VectorType>(operand.getType()).getElementType();
auto expandedType = VectorType::get(expandedShape, eltType);
expandedOperands[i] =
- builder.create<vector::ShapeCastOp>(expandedType, operand);
+ vector::ShapeCastOp::create(builder, expandedType, operand);
}
}
@@ -148,7 +148,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
SmallVector<Value> extracted(expandedOperands.size());
for (const auto &tuple : llvm::enumerate(expandedOperands))
extracted[tuple.index()] =
- builder.create<vector::ExtractOp>(tuple.value(), offsets);
+ vector::ExtractOp::create(builder, tuple.value(), offsets);
results[i] = compute(extracted);
}
@@ -156,16 +156,16 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Stitch results together into one large vector.
Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
- Value result = builder.create<arith::ConstantOp>(
- resultExpandedType, builder.getZeroAttr(resultExpandedType));
+ Value result = arith::ConstantOp::create(
+ builder, resultExpandedType, builder.getZeroAttr(resultExpandedType));
for (int64_t i = 0; i < maxIndex; ++i)
- result = builder.create<vector::InsertOp>(results[i], result,
- delinearize(i, strides));
+ result = vector::InsertOp::create(builder, results[i], result,
+ delinearize(i, strides));
// Reshape back to the original vector shape.
- return builder.create<vector::ShapeCastOp>(
- VectorType::get(inputShape, resultEltType), result);
+ return vector::ShapeCastOp::create(
+ builder, VectorType::get(inputShape, resultEltType), result);
}
//----------------------------------------------------------------------------//
@@ -173,28 +173,28 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
//----------------------------------------------------------------------------//
static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
- return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
+ return arith::ConstantOp::create(builder, builder.getBoolAttr(value));
}
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
Type elementType) {
assert((elementType.isF16() || elementType.isF32()) &&
"x must be f16 or f32 type.");
- return builder.create<arith::ConstantOp>(
- builder.getFloatAttr(elementType, value));
+ return arith::ConstantOp::create(builder,
+ builder.getFloatAttr(elementType, value));
}
static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
- return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
+ return arith::ConstantOp::create(builder, builder.getF32FloatAttr(value));
}
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
- return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
+ return arith::ConstantOp::create(builder, builder.getI32IntegerAttr(value));
}
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
- return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
+ return arith::BitcastOp::create(builder, builder.getF32Type(), i32Value);
}
//----------------------------------------------------------------------------//
@@ -203,15 +203,17 @@ static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
// Return the minimum of the two values or NaN if value is NaN
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
- return builder.create<arith::SelectOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
+ return arith::SelectOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound),
value, bound);
}
// Return the maximum of the two values or NaN if value is NaN
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
- return builder.create<arith::SelectOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
+ return arith::SelectOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound),
value, bound);
}
@@ -241,24 +243,24 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
// Bitcast to i32 for bitwise operations.
- Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
- Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
- Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
+ Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf);
+ Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask);
+ Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg);
// Compute normalized fraction.
- Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
- Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
- Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
+ Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask));
+ Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half));
+ Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1);
// Compute exponent.
- Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
- Value biasedExponentBits = builder.create<arith::ShRUIOp>(
- builder.create<arith::BitcastOp>(i32Vec, arg0),
+ Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg);
+ Value biasedExponentBits = arith::ShRUIOp::create(
+ builder, arith::BitcastOp::create(builder, i32Vec, arg0),
bcast(i32Cst(builder, 23)));
Value biasedExponent =
- builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
+ arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits);
Value exponent =
- builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
+ arith::SubFOp::create(builder, biasedExponent, bcast(cst126f));
return {normalizedFraction, exponent};
}
@@ -278,10 +280,10 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
// Set the exponent bias to zero.
auto bias = bcast(i32Cst(builder, 127));
- Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
+ Value biasedArg = arith::AddIOp::create(builder, arg, bias);
Value exp2ValueInt =
- builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
- Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
+ arith::ShLIOp::create(builder, biasedArg, exponetBitLocation);
+ Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt);
return exp2ValueF32;
}
@@ -300,10 +302,10 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
if (coeffs.size() == 1)
return coeffs[0];
- Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
- coeffs[coeffs.size() - 2]);
+ Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1],
+ coeffs[coeffs.size() - 2]);
for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
- res = builder.create<math::FmaOp>(x, res, coeffs[i]);
+ res = math::FmaOp::create(builder, x, res, coeffs[i]);
}
return res;
}
@@ -343,9 +345,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
Location loc = op->getLoc();
SmallVector<Value> operands;
for (auto operand : op->getOperands())
- operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
+ operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand));
auto result =
- rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
+ T::create(rewriter, loc, TypeRange{newType}, operands, op->getAttrs());
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
return success();
}
@@ -393,18 +395,18 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- Value abs = builder.create<math::AbsFOp>(operand);
+ Value abs = math::AbsFOp::create(builder, operand);
auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
// When 0.66 < x <= 2.41 we do (x-1) / (x+1):
auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
Value cmp2 =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
- Value addone = builder.create<arith::AddFOp>(abs, one);
- Value subone = builder.create<arith::SubFOp>(abs, one);
- Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
- Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds);
+ Value addone = arith::AddFOp::create(builder, abs, one);
+ Value subone = arith::SubFOp::create(builder, abs, one);
+ Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs);
+ Value xden = arith::SelectOp::create(builder, cmp2, addone, one);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -413,12 +415,12 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
// Break into the <= 0.66 or > 2.41 we do x or 1/x:
auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
Value cmp1 =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
- xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
- xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8);
+ xnum = arith::SelectOp::create(builder, cmp1, one, xnum);
+ xden = arith::SelectOp::create(builder, cmp1, abs, xden);
- Value x = builder.create<arith::DivFOp>(xnum, xden);
- Value xx = builder.create<arith::MulFOp>(x, x);
+ Value x = arith::DivFOp::create(builder, xnum, xden);
+ Value xx = arith::MulFOp::create(builder, x, x);
// Perform the Taylor series approximation for atan over the range
// [0.0, 0.66].
@@ -435,31 +437,31 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
// Apply the polynomial approximation for the numerator:
Value n = p0;
- n = builder.create<math::FmaOp>(xx, n, p1);
- n = builder.create<math::FmaOp>(xx, n, p2);
- n = builder.create<math::FmaOp>(xx, n, p3);
- n = builder.create<math::FmaOp>(xx, n, p4);
- n = builder.create<arith::MulFOp>(n, xx);
+ n = math::FmaOp::create(builder, xx, n, p1);
+ n = math::FmaOp::create(builder, xx, n, p2);
+ n = math::FmaOp::create(builder, xx, n, p3);
+ n = math::FmaOp::create(builder, xx, n, p4);
+ n = arith::MulFOp::create(builder, n, xx);
// Apply the polynomial approximation for the denominator:
Value d = q0;
- d = builder.create<math::FmaOp>(xx, d, q1);
- d = builder.create<math::FmaOp>(xx, d, q2);
- d = builder.create<math::FmaOp>(xx, d, q3);
- d = builder.create<math::FmaOp>(xx, d, q4);
+ d = math::FmaOp::create(builder, xx, d, q1);
+ d = math::FmaOp::create(builder, xx, d, q2);
+ d = math::FmaOp::create(builder, xx, d, q3);
+ d = math::FmaOp::create(builder, xx, d, q4);
// Compute approximation of theta:
- Value ans0 = builder.create<arith::DivFOp>(n, d);
- ans0 = builder.create<math::FmaOp>(ans0, x, x);
+ Value ans0 = arith::DivFOp::create(builder, n, d);
+ ans0 = math::FmaOp::create(builder, ans0, x, x);
// Correct for the input mapping's angles:
Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
- Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
- Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
+ Value ans2 = arith::AddFOp::create(builder, mpi4, ans0);
+ Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0);
Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
- Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
- ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
+ Value ans1 = arith::SubFOp::create(builder, mpi2, ans0);
+ ans = arith::SelectOp::create(builder, cmp1, ans1, ans);
// Correct for signing of the input.
rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
@@ -492,44 +494,46 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
std::optional<VectorShape> shape = vectorShape(op.getResult());
// Compute atan in the valid range.
- auto div = builder.create<arith::DivFOp>(y, x);
- auto atan = builder.create<math::AtanOp>(div);
+ auto div = arith::DivFOp::create(builder, y, x);
+ auto atan = math::AtanOp::create(builder, div);
// Determine what the atan would be for a 180 degree rotation.
auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
- auto addPi = builder.create<arith::AddFOp>(atan, pi);
- auto subPi = builder.create<arith::SubFOp>(atan, pi);
+ auto addPi = arith::AddFOp::create(builder, atan, pi);
+ auto subPi = arith::SubFOp::create(builder, atan, pi);
auto atanGt =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
- auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero);
+ auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi);
// Determine whether to directly use atan or use the 180 degree flip
- auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
- Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
+ auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero);
+ Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan);
// Handle x = 0, y > 0
Value xZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
- Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
- Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero);
+ Value yGt =
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero);
+ Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt);
auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
- result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
+ result = arith::SelectOp::create(builder, isHalfPi, halfPi, result);
// Handle x = 0, y < 0
- Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
- Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
+ Value yLt =
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero);
+ Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt);
auto negativeHalfPiPi =
broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
- result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
- result);
+ result = arith::SelectOp::create(builder, isNegativeHalfPiPi,
+ negativeHalfPiPi, result);
// Handle x = 0, y = 0;
Value yZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
- Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero);
+ Value isNan = arith::AndIOp::create(builder, xZero, yZero);
Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
- result = builder.create<arith::SelectOp>(isNan, cstNan, result);
+ result = arith::SelectOp::create(builder, isNan, cstNan, result);
rewriter.replaceOp(op, result);
return success();
@@ -569,9 +573,9 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
// Mask for tiny values that are approximated with `operand`.
Value tiny = bcast(f32Cst(builder, 0.0004f));
- Value tinyMask = builder.create<arith::CmpFOp>(
- arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
- tiny);
+ Value tinyMask = arith::CmpFOp::create(
+ builder, arith::CmpFPredicate::OLT,
+ math::AbsFOp::create(builder, op.getOperand()), tiny);
// The monomial coefficients of the numerator polynomial (odd).
Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
@@ -589,25 +593,25 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
// Since the polynomials are odd/even, we need x^2.
- Value x2 = builder.create<arith::MulFOp>(x, x);
+ Value x2 = arith::MulFOp::create(builder, x, x);
// Evaluate the numerator polynomial p.
- Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
- p = builder.create<math::FmaOp>(x2, p, alpha9);
- p = builder.create<math::FmaOp>(x2, p, alpha7);
- p = builder.create<math::FmaOp>(x2, p, alpha5);
- p = builder.create<math::FmaOp>(x2, p, alpha3);
- p = builder.create<math::FmaOp>(x2, p, alpha1);
- p = builder.create<arith::MulFOp>(x, p);
+ Value p = math::FmaOp::create(builder, x2, alpha13, alpha11);
+ p = math::FmaOp::create(builder, x2, p, alpha9);
+ p = math::FmaOp::create(builder, x2, p, alpha7);
+ p = math::FmaOp::create(builder, x2, p, alpha5);
+ p = math::FmaOp::create(builder, x2, p, alpha3);
+ p = math::FmaOp::create(builder, x2, p, alpha1);
+ p = arith::MulFOp::create(builder, x, p);
// Evaluate the denominator polynomial q.
- Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
- q = builder.create<math::FmaOp>(x2, q, beta2);
- q = builder.create<math::FmaOp>(x2, q, beta0);
+ Value q = math::FmaOp::create(builder, x2, beta6, beta4);
+ q = math::FmaOp::create(builder, x2, q, beta2);
+ q = math::FmaOp::create(builder, x2, q, beta0);
// Divide the numerator by the denominator.
- Value res = builder.create<arith::SelectOp>(
- tinyMask, x, builder.create<arith::DivFOp>(p, q));
+ Value res = arith::SelectOp::create(builder, tinyMask, x,
+ arith::DivFOp::create(builder, p, q));
rewriter.replaceOp(op, res);
@@ -690,57 +694,57 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
// e -= 1;
// x = x + x - 1.0;
// } else { x = x - 1.0; }
- Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
- cstCephesSQRTHF);
- Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
+ Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x,
+ cstCephesSQRTHF);
+ Value tmp = arith::SelectOp::create(builder, mask, x, cstZero);
- x = builder.create<arith::SubFOp>(x, cstOne);
- e = builder.create<arith::SubFOp>(
- e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
- x = builder.create<arith::AddFOp>(x, tmp);
+ x = arith::SubFOp::create(builder, x, cstOne);
+ e = arith::SubFOp::create(
+ builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero));
+ x = arith::AddFOp::create(builder, x, tmp);
- Value x2 = builder.create<arith::MulFOp>(x, x);
- Value x3 = builder.create<arith::MulFOp>(x2, x);
+ Value x2 = arith::MulFOp::create(builder, x, x);
+ Value x3 = arith::MulFOp::create(builder, x2, x);
// Evaluate the polynomial approximant of degree 8 in three parts.
Value y0, y1, y2;
- y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
- y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
- y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
- y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
- y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
- y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
- y0 = builder.create<math::FmaOp>(y0, x3, y1);
- y0 = builder.create<math::FmaOp>(y0, x3, y2);
- y0 = builder.create<arith::MulFOp>(y0, x3);
-
- y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
- x = builder.create<arith::AddFOp>(x, y0);
+ y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1);
+ y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4);
+ y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7);
+ y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2);
+ y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5);
+ y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8);
+ y0 = math::FmaOp::create(builder, y0, x3, y1);
+ y0 = math::FmaOp::create(builder, y0, x3, y2);
+ y0 = arith::MulFOp::create(builder, y0, x3);
+
+ y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0);
+ x = arith::AddFOp::create(builder, x, y0);
if (base2) {
Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
- x = builder.create<math::FmaOp>(x, cstLog2e, e);
+ x = math::FmaOp::create(builder, x, cstLog2e, e);
} else {
Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
- x = builder.create<math::FmaOp>(e, cstLn2, x);
+ x = math::FmaOp::create(builder, e, cstLn2, x);
}
- Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
- op.getOperand(), cstZero);
- Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- op.getOperand(), cstZero);
- Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- op.getOperand(), cstPosInf);
+ Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT,
+ op.getOperand(), cstZero);
+ Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
+ op.getOperand(), cstZero);
+ Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
+ op.getOperand(), cstPosInf);
// Filter out invalid values:
// • x == 0 -> -INF
// • x < 0 -> NAN
// • x == +INF -> +INF
- Value aproximation = builder.create<arith::SelectOp>(
- zeroMask, cstMinusInf,
- builder.create<arith::SelectOp>(
- invalidMask, cstNan,
- builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
+ Value aproximation = arith::SelectOp::create(
+ builder, zeroMask, cstMinusInf,
+ arith::SelectOp::create(
+ builder, invalidMask, cstNan,
+ arith::SelectOp::create(builder, posInfMask, cstPosInf, x)));
rewriter.replaceOp(op, aproximation);
@@ -805,17 +809,18 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
// "logLarge" below.
Value cstOne = bcast(f32Cst(builder, 1.0f));
Value x = op.getOperand();
- Value u = builder.create<arith::AddFOp>(x, cstOne);
+ Value u = arith::AddFOp::create(builder, x, cstOne);
Value uSmall =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
- Value logU = builder.create<math::LogOp>(u);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne);
+ Value logU = math::LogOp::create(builder, u);
Value uInf =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
- Value logLarge = builder.create<arith::MulFOp>(
- x, builder.create<arith::DivFOp>(
- logU, builder.create<arith::SubFOp>(u, cstOne)));
- Value approximation = builder.create<arith::SelectOp>(
- builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU);
+ Value logLarge = arith::MulFOp::create(
+ builder, x,
+ arith::DivFOp::create(builder, logU,
+ arith::SubFOp::create(builder, u, cstOne)));
+ Value approximation = arith::SelectOp::create(
+ builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge);
rewriter.replaceOp(op, approximation);
return success();
}
@@ -853,36 +858,37 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
};
auto fma = [&](Value a, Value b, Value c) -> Value {
- return builder.create<math::FmaOp>(a, b, c);
+ return math::FmaOp::create(builder, a, b, c);
};
auto mul = [&](Value a, Value b) -> Value {
- return builder.create<arith::MulFOp>(a, b);
+ return arith::MulFOp::create(builder, a, b);
};
auto sub = [&](Value a, Value b) -> Value {
- return builder.create<arith::SubFOp>(a, b);
+ return arith::SubFOp::create(builder, a, b);
};
- auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };
+ auto abs = [&](Value a) -> Value { return math::AbsFOp::create(builder, a); };
- auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };
+ auto sqrt = [&](Value a) -> Value {
+ return math::SqrtOp::create(builder, a);
+ };
auto scopy = [&](Value a, Value b) -> Value {
- return builder.create<math::CopySignOp>(a, b);
+ return math::CopySignOp::create(builder, a, b);
};
auto sel = [&](Value a, Value b, Value c) -> Value {
- return builder.create<arith::SelectOp>(a, b, c);
+ return arith::SelectOp::create(builder, a, b, c);
};
Value abso = abs(operand);
Value aa = mul(operand, operand);
Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
- Value gt =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
- bcast(floatCst(builder, 0.5, elementType)));
+ Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa,
+ bcast(floatCst(builder, 0.5, elementType)));
Value x = sel(gt, opp, abso);
@@ -948,51 +954,51 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
};
auto fma = [&](Value a, Value b, Value c) -> Value {
- return builder.create<math::FmaOp>(a, b, c);
+ return math::FmaOp::create(builder, a, b, c);
};
auto mul = [&](Value a, Value b) -> Value {
- return builder.create<arith::MulFOp>(a, b);
+ return arith::MulFOp::create(builder, a, b);
};
- Value negOperand = builder.create<arith::NegFOp>(operand);
+ Value negOperand = arith::NegFOp::create(builder, operand);
Value zero = bcast(floatCst(builder, 0.0, elementType));
Value half = bcast(floatCst(builder, 0.5, elementType));
Value negOne = bcast(floatCst(builder, -1.0, elementType));
Value selR =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
- Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero);
+ Value r = arith::SelectOp::create(builder, selR, negOperand, operand);
Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
Value firstPred =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst);
Value trueVal =
fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
- builder.create<math::AsinOp>(r));
+ math::AsinOp::create(builder, r));
- Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
- falseVal = builder.create<math::AsinOp>(falseVal);
+ Value falseVal = math::SqrtOp::create(builder, fma(half, r, half));
+ falseVal = math::AsinOp::create(builder, falseVal);
falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
- r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
+ r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal);
// Check whether the operand lies in between [-1.0, 0.0).
- Value greaterThanNegOne =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
+ Value greaterThanNegOne = arith::CmpFOp::create(
+ builder, arith::CmpFPredicate::OGE, operand, negOne);
Value lessThanZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
Value betweenNegOneZero =
- builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
+ arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero);
trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
- builder.create<arith::NegFOp>(r));
+ arith::NegFOp::create(builder, r));
Value finalVal =
- builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
+ arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r);
rewriter.replaceOp(op, finalVal);
return success();
@@ -1075,9 +1081,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
Value isNegativeArg =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
- Value negArg = builder.create<arith::NegFOp>(operand);
- Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
+ Value negArg = arith::NegFOp::create(builder, operand);
+ Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand);
Value offset = offsets[0];
Value p[polyDegree + 1];
@@ -1091,30 +1097,30 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
Value isLessThanBound[intervalsCount];
for (int j = 0; j < intervalsCount - 1; ++j) {
isLessThanBound[j] =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[j]);
for (int i = 0; i <= polyDegree; ++i) {
- p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
- pp[j + 1][i]);
- q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
- qq[j + 1][i]);
+ p[i] = arith::SelectOp::create(builder, isLessThanBound[j], p[i],
+ pp[j + 1][i]);
+ q[i] = arith::SelectOp::create(builder, isLessThanBound[j], q[i],
+ qq[j + 1][i]);
}
- offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
- offsets[j + 1]);
+ offset = arith::SelectOp::create(builder, isLessThanBound[j], offset,
+ offsets[j + 1]);
}
- isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
- arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
+ isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(
+ builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
Value pPoly = makePolynomialCalculation(builder, p, x);
Value qPoly = makePolynomialCalculation(builder, q, x);
- Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
- Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
- formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
- formula, one);
+ Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly);
+ Value formula = arith::AddFOp::create(builder, offset, rationalPoly);
+ formula = arith::SelectOp::create(
+ builder, isLessThanBound[intervalsCount - 1], formula, one);
// erf is odd function: erf(x) = -erf(-x).
- Value negFormula = builder.create<arith::NegFOp>(formula);
+ Value negFormula = arith::NegFOp::create(builder, formula);
Value res =
- builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
+ arith::SelectOp::create(builder, isNegativeArg, negFormula, formula);
rewriter.replaceOp(op, res);
@@ -1155,65 +1161,67 @@ ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
Value posInf = bcast(floatCst(builder, INFINITY, et));
Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
- Value a = builder.create<math::AbsFOp>(x);
- Value p = builder.create<arith::AddFOp>(a, pos2);
- Value r = builder.create<arith::DivFOp>(one, p);
- Value q = builder.create<math::FmaOp>(neg4, r, one);
- Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
- neg2, a);
- Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
- q = builder.create<math::FmaOp>(r, e, q);
+ Value a = math::AbsFOp::create(builder, x);
+ Value p = arith::AddFOp::create(builder, a, pos2);
+ Value r = arith::DivFOp::create(builder, one, p);
+ Value q = math::FmaOp::create(builder, neg4, r, one);
+ Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one),
+ neg2, a);
+ Value e =
+ math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t);
+ q = math::FmaOp::create(builder, r, e, q);
p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
- p = builder.create<math::FmaOp>(p, q, c1);
+ p = math::FmaOp::create(builder, p, q, c1);
Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
- p = builder.create<math::FmaOp>(p, q, c2);
+ p = math::FmaOp::create(builder, p, q, c2);
Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
- p = builder.create<math::FmaOp>(p, q, c3);
+ p = math::FmaOp::create(builder, p, q, c3);
Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
- p = builder.create<math::FmaOp>(p, q, c4);
+ p = math::FmaOp::create(builder, p, q, c4);
Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
- p = builder.create<math::FmaOp>(p, q, c5);
+ p = math::FmaOp::create(builder, p, q, c5);
Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
- p = builder.create<math::FmaOp>(p, q, c6);
+ p = math::FmaOp::create(builder, p, q, c6);
Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
- p = builder.create<math::FmaOp>(p, q, c7);
+ p = math::FmaOp::create(builder, p, q, c7);
Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
- p = builder.create<math::FmaOp>(p, q, c8);
+ p = math::FmaOp::create(builder, p, q, c8);
Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
- p = builder.create<math::FmaOp>(p, q, c9);
-
- Value d = builder.create<math::FmaOp>(pos2, a, one);
- r = builder.create<arith::DivFOp>(one, d);
- q = builder.create<math::FmaOp>(p, r, r);
- Value negfa = builder.create<arith::NegFOp>(a);
- Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
- Value psubq = builder.create<arith::SubFOp>(p, q);
- e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
- r = builder.create<math::FmaOp>(e, r, q);
-
- Value s = builder.create<arith::MulFOp>(a, a);
- e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
-
- t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
- r = builder.create<math::FmaOp>(
- r, e,
- builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
-
- Value isNotLessThanInf = builder.create<arith::XOrIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
+ p = math::FmaOp::create(builder, p, q, c9);
+
+ Value d = math::FmaOp::create(builder, pos2, a, one);
+ r = arith::DivFOp::create(builder, one, d);
+ q = math::FmaOp::create(builder, p, r, r);
+ Value negfa = arith::NegFOp::create(builder, a);
+ Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf);
+ Value psubq = arith::SubFOp::create(builder, p, q);
+ e = math::FmaOp::create(builder, fmaqah, pos2, psubq);
+ r = math::FmaOp::create(builder, e, r, q);
+
+ Value s = arith::MulFOp::create(builder, a, a);
+ e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s));
+
+ t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s);
+ r = math::FmaOp::create(
+ builder, r, e,
+ arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t));
+
+ Value isNotLessThanInf = arith::XOrIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf),
trueValue);
- r = builder.create<arith::SelectOp>(isNotLessThanInf,
- builder.create<arith::AddFOp>(x, x), r);
+ r = arith::SelectOp::create(builder, isNotLessThanInf,
+ arith::AddFOp::create(builder, x, x), r);
Value isGreaterThanClamp =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
- r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal);
+ r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r);
Value isNegative =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
- r = builder.create<arith::SelectOp>(
- isNegative, builder.create<arith::SubFOp>(pos2, r), r);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero);
+ r = arith::SelectOp::create(builder, isNegative,
+ arith::SubFOp::create(builder, pos2, r), r);
rewriter.replaceOp(op, r);
return success();
@@ -1235,8 +1243,9 @@ Value clampWithNormals(ImplicitLocOpBuilder &builder,
};
auto selectCmp = [&builder](auto pred, Value value, Value bound) {
- return builder.create<arith::SelectOp>(
- builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
+ return arith::SelectOp::create(
+ builder, arith::CmpFOp::create(builder, pred, value, bound), value,
+ bound);
};
// Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
@@ -1268,17 +1277,17 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto add = [&](Value a, Value b) -> Value {
- return builder.create<arith::AddFOp>(a, b);
+ return arith::AddFOp::create(builder, a, b);
};
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
};
- auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
+ auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
auto fmla = [&](Value a, Value b, Value c) {
- return builder.create<math::FmaOp>(a, b, c);
+ return math::FmaOp::create(builder, a, b, c);
};
auto mul = [&](Value a, Value b) -> Value {
- return builder.create<arith::MulFOp>(a, b);
+ return arith::MulFOp::create(builder, a, b);
};
// Polynomial approximation from Cephes.
@@ -1382,7 +1391,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
// Convert n' to an i32. This is safe because we clamped it above.
auto i32Vec = broadcast(builder.getI32Type(), shape);
- Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
+ Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n);
// Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
Value pow2 = exp2I32(builder, nI32);
@@ -1430,26 +1439,26 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
Value cstOne = bcast(f32Cst(builder, 1.0f));
Value cstNegOne = bcast(f32Cst(builder, -1.0f));
Value x = op.getOperand();
- Value u = builder.create<math::ExpOp>(x);
+ Value u = math::ExpOp::create(builder, x);
Value uEqOneOrNaN =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
- Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
- Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
- arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne);
+ Value uMinusOne = arith::SubFOp::create(builder, u, cstOne);
+ Value uMinusOneEqNegOne = arith::CmpFOp::create(
+ builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
// logU = log(u) ~= x
- Value logU = builder.create<math::LogOp>(u);
+ Value logU = math::LogOp::create(builder, u);
// Detect exp(x) = +inf; written this way to avoid having to form +inf.
Value isInf =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u);
// (u - 1) * (x / ~x)
- Value expm1 = builder.create<arith::MulFOp>(
- uMinusOne, builder.create<arith::DivFOp>(x, logU));
- expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
- Value approximation = builder.create<arith::SelectOp>(
- uEqOneOrNaN, x,
- builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
+ Value expm1 = arith::MulFOp::create(builder, uMinusOne,
+ arith::DivFOp::create(builder, x, logU));
+ expm1 = arith::SelectOp::create(builder, isInf, u, expm1);
+ Value approximation = arith::SelectOp::create(
+ builder, uEqOneOrNaN, x,
+ arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1));
rewriter.replaceOp(op, approximation);
return success();
}
@@ -1494,40 +1503,40 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return broadcast(builder, value, shape);
};
auto mul = [&](Value a, Value b) -> Value {
- return builder.create<arith::MulFOp>(a, b);
+ return arith::MulFOp::create(builder, a, b);
};
auto sub = [&](Value a, Value b) -> Value {
- return builder.create<arith::SubFOp>(a, b);
+ return arith::SubFOp::create(builder, a, b);
};
- auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
+ auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
auto i32Vec = broadcast(builder.getI32Type(), shape);
auto fPToSingedInteger = [&](Value a) -> Value {
- return builder.create<arith::FPToSIOp>(i32Vec, a);
+ return arith::FPToSIOp::create(builder, i32Vec, a);
};
auto modulo4 = [&](Value a) -> Value {
- return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
+ return arith::AndIOp::create(builder, a, bcast(i32Cst(builder, 3)));
};
auto isEqualTo = [&](Value a, Value b) -> Value {
- return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
+ return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b);
};
auto isGreaterThan = [&](Value a, Value b) -> Value {
- return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
+ return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b);
};
auto select = [&](Value cond, Value t, Value f) -> Value {
- return builder.create<arith::SelectOp>(cond, t, f);
+ return arith::SelectOp::create(builder, cond, t, f);
};
auto fmla = [&](Value a, Value b, Value c) {
- return builder.create<math::FmaOp>(a, b, c);
+ return math::FmaOp::create(builder, a, b, c);
};
auto bitwiseOr = [&](Value a, Value b) {
- return builder.create<arith::OrIOp>(a, b);
+ return arith::OrIOp::create(builder, a, b);
};
Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
@@ -1624,7 +1633,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
intTy = broadcast(intTy, shape);
auto bconst = [&](TypedAttr attr) -> Value {
- Value value = b.create<arith::ConstantOp>(attr);
+ Value value = arith::ConstantOp::create(b, attr);
return broadcast(b, value, shape);
};
@@ -1641,44 +1650,44 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
// union {int ix; float x;};
// x = x0;
// ix = ix/4 + ix/16;
- Value absValue = b.create<math::AbsFOp>(operand);
- Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
- Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
- Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
- intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
+ Value absValue = math::AbsFOp::create(b, operand);
+ Value intValue = arith::BitcastOp::create(b, intTy, absValue);
+ Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo);
+ Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
+ intValue = arith::AddIOp::create(b, divideBy4, divideBy16);
// ix = ix + ix/16;
- divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
- intValue = b.create<arith::AddIOp>(intValue, divideBy16);
+ divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
+ intValue = arith::AddIOp::create(b, intValue, divideBy16);
// ix = ix + ix/256;
- Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
- intValue = b.create<arith::AddIOp>(intValue, divideBy256);
+ Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight);
+ intValue = arith::AddIOp::create(b, intValue, divideBy256);
// ix = 0x2a5137a0 + ix;
- intValue = b.create<arith::AddIOp>(intValue, intMagic);
+ intValue = arith::AddIOp::create(b, intValue, intMagic);
// Perform one newtons step:
// x = 0.33333333f*(2.0f*x + x0/(x*x));
- Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
- Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
- Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
- Value divSquared = b.create<arith::DivFOp>(absValue, squared);
- floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
- floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+ Value floatValue = arith::BitcastOp::create(b, floatTy, intValue);
+ Value squared = arith::MulFOp::create(b, floatValue, floatValue);
+ Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
+ Value divSquared = arith::DivFOp::create(b, absValue, squared);
+ floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
+ floatValue = arith::MulFOp::create(b, floatValue, fpThird);
// x = 0.33333333f*(2.0f*x + x0/(x*x));
- squared = b.create<arith::MulFOp>(floatValue, floatValue);
- mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
- divSquared = b.create<arith::DivFOp>(absValue, squared);
- floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
- floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+ squared = arith::MulFOp::create(b, floatValue, floatValue);
+ mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
+ divSquared = arith::DivFOp::create(b, absValue, squared);
+ floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
+ floatValue = arith::MulFOp::create(b, floatValue, fpThird);
// Check for zero and restore sign.
Value isZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
- floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
- floatValue = b.create<math::CopySignOp>(floatValue, operand);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero);
+ floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue);
+ floatValue = math::CopySignOp::create(b, floatValue, operand);
rewriter.replaceOp(op, floatValue);
return success();
@@ -1719,29 +1728,29 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
- Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
+ Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf);
// Select only the inverse sqrt of positive normals (denormals are
// flushed to zero).
- Value ltMinMask = builder.create<arith::CmpFOp>(
- arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
- Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- op.getOperand(), cstPosInf);
- Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
+ Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT,
+ op.getOperand(), cstMinNormPos);
+ Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
+ op.getOperand(), cstPosInf);
+ Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask);
// Compute an approximate result.
Value yApprox = handleMultidimensionalVectors(
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
- return builder.create<x86vector::RsqrtOp>(operands);
+ return x86vector::RsqrtOp::create(builder, operands);
});
// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
// It is essential to evaluate the inner term like this because forming
// y_n^2 may over- or underflow.
- Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
- Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
- Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
+ Value inner = arith::MulFOp::create(builder, negHalf, yApprox);
+ Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive);
+ Value yNewton = arith::MulFOp::create(builder, yApprox, fma);
// Select the result of the Newton-Raphson step for positive normal arguments.
// For other arguments, choose the output of the intrinsic. This will
@@ -1749,7 +1758,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
// x is zero or a positive denormalized float (equivalent to flushing positive
// denormalized inputs to zero).
Value res =
- builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
+ arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton);
rewriter.replaceOp(op, res);
return success();
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index cf506d1e7812ba..ed0df4e8c58122 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -91,7 +91,7 @@ SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
values.emplace_back(*(dyn++));
} else {
TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
- values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+ values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
}
}
return values;
@@ -316,10 +316,10 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
if (!newShardOp) {
auto shardingOp =
- builder.create<ShardingOp>(operandValue.getLoc(), sharding);
- newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ ShardingOp::create(builder, operandValue.getLoc(), sharding);
+ newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue,
+ shardingOp,
+ /*annotate_for_users*/ false);
}
operandValue.replaceUsesWithIf(
newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -330,9 +330,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
return;
}
- auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
- newShardOp.getSharding(),
- /*annotate_for_users*/ true);
+ auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp,
+ newShardOp.getSharding(),
+ /*annotate_for_users*/ true);
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
@@ -378,10 +378,10 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
builder.setInsertionPoint(operandOp);
auto shardingOp =
- builder.create<ShardingOp>(operand.get().getLoc(), sharding);
+ ShardingOp::create(builder, operand.get().getLoc(), sharding);
auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ true);
+ ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ true);
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -395,8 +395,8 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
builder.setInsertionPoint(newShardOp);
auto newPreceedingShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ false);
rewriter.replaceUsesWithIf(
newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
return use.getOwner() == newShardOp.getOperation();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 9da3c9a3dd160a..db5fd6e494da18 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -91,15 +91,15 @@ struct MeshShapeFolder
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
} else {
// Fold static mesh axes.
- newResults[i] = builder.create<arith::ConstantOp>(
- builder.getIndexAttr(meshAxisSize));
+ newResults[i] = arith::ConstantOp::create(
+ builder, builder.getIndexAttr(meshAxisSize));
}
}
// Leave only the dynamic mesh axes to be queried.
if (!newShapeOpMeshAxes.empty()) {
MeshShapeOp newShapeOp =
- builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+ MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index d7b7234f693473..1e54affa8198f2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -265,7 +265,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
- Value allGatherResult = builder.create<AllGatherOp>(
+ Value allGatherResult = AllGatherOp::create(
+ builder,
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
@@ -273,7 +274,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
+ tensor::CastOp::create(builder, targetShape, allGatherResult)
+ .getResult());
return {targetShard, targetSharding};
}
@@ -398,7 +400,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
targetTensorAxis);
- Value allToAllResult = builder.create<AllToAllOp>(
+ Value allToAllResult = AllToAllOp::create(
+ builder,
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
@@ -406,7 +409,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
@@ -477,15 +480,16 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// Extract core from source and copy into destination core.
auto noVals = ValueRange{};
- auto initVal = builder.create<tensor::EmptyOp>(
- sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
- auto core = builder.create<tensor::ExtractSliceOp>(
- sourceShard.getLoc(),
+ auto initVal =
+ tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
+ sourceShard.getType().getElementType());
+ auto core = tensor::ExtractSliceOp::create(
+ builder, sourceShard.getLoc(),
RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
- auto initOprnd = builder.create<tensor::InsertSliceOp>(
- sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
- coreShape, strides);
+ auto initOprnd = tensor::InsertSliceOp::create(
+ builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
+ tgtCoreOffs, coreShape, strides);
// Finally update the halo.
auto updateHaloResult =
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index f08ef75d8a0043..6ae95ae1f8a49c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -49,10 +49,11 @@ struct ProcessMultiIndexOpLowering
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
- Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
- ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
+ Value linearIndex = ProcessLinearIndexOp::create(builder, mesh);
+ ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults();
SmallVector<Value> completeMultiIndex =
- builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
+ affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
+ meshShape)
.getMultiIndex();
SmallVector<Value> multiIndex;
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
@@ -101,32 +102,33 @@ struct AllSliceOpLowering
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
- Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+ Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
Operation::result_range processInGroupMultiIndex =
- builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+ ProcessMultiIndexOp::create(builder, mesh.getSymName(),
+ op.getMeshAxes())
.getResults();
Operation::result_range processGroupShape =
- builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+ MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes())
.getResult();
Value processGroupSize =
createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
int64_t sliceAxis = op.getSliceAxis().getSExtValue();
Value operandSliceAxisSize =
- builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
+ tensor::DimOp::create(builder, op.getOperand(), sliceAxis);
Value operandSliceAxisSizeModProcessGroupSize =
- builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
- Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
- zero);
- builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
- "Slicing a tensor with axis size that is "
- "not exactly divisible by the "
- "mesh process group size is not supported.");
+ arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize);
+ Value isTargetShapeExactlyDivisible =
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ operandSliceAxisSizeModProcessGroupSize, zero);
+ cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
+ "Slicing a tensor with axis size that is "
+ "not exactly divisible by the "
+ "mesh process group size is not supported.");
Value resultSliceAxisSize =
- builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
+ arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
@@ -139,7 +141,7 @@ struct AllSliceOpLowering
if (i == sliceAxis) {
sizes.emplace_back(resultSliceAxisSize);
} else {
- Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
+ Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i);
sizes.emplace_back(dimSize);
}
}
@@ -152,10 +154,10 @@ struct AllSliceOpLowering
resultSliceAxisSize);
SmallVector<OpFoldResult> strides(
operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
- Value slice = builder.create<tensor::ExtractSliceOp>(
- op.getOperand(), offsets, sizes, strides);
+ Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(),
+ offsets, sizes, strides);
Value newResult =
- builder.create<tensor::CastOp>(op.getResult().getType(), slice);
+ tensor::CastOp::create(builder, op.getResult().getType(), slice);
rewriter.replaceAllUsesWith(op.getResult(), newResult);
return success();
@@ -201,7 +203,7 @@ TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder) {
Operation::result_range meshShape =
- builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
+ mesh::MeshShapeOp::create(builder, mesh, axes).getResults();
return cast<TypedValue<IndexType>>(arith::createProduct(
builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
builder.getIndexType()));
@@ -212,13 +214,14 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder) {
Operation::result_range processGroupShape =
- builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
+ MeshShapeOp::create(builder, mesh, meshAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
auto res = dyn_cast<Value>(processInGroupLinearIndex);
if (!res)
- res = builder.create<arith::ConstantIndexOp>(
+ res = arith::ConstantIndexOp::create(
+ builder,
cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
return cast<TypedValue<IndexType>>(res);
}
@@ -227,7 +230,7 @@ TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder) {
return createProcessLinearIndex(
- mesh, builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(),
+ mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(),
meshAxes, builder);
}
} // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index d2c94b124cdfb1..5d253c1199dc0e 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -333,15 +333,15 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
// srcElement = (pred) ? prevSrcElements : 0;
//
Location loc = asyncCopyOp->getLoc();
- Value dstElements =
- rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
+ Value dstElements = arith::ConstantOp::create(
+ rewriter, loc, asyncCopyOp.getDstElementsAttr());
Value originalSrcElement =
asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
- Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto srcElements = rewriter.create<arith::SelectOp>(
- loc, predicate, originalSrcElement, c0Index);
- auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
- loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+ Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
+ originalSrcElement, c0Index);
+ auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
+ rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
UnitAttr());
@@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
for (auto indexing : indexings) {
Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
- auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
+ auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
res.push_back(load);
}
return res;
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
Type elementType = getElementTypeOrSelf(memref.getType());
auto vt = VectorType::get(vectorShape, elementType);
- Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
+ Value res = vector::SplatOp::create(b, loc, vt, loads[0]);
foreachIndividualVectorElement(
res,
/*applyFn=*/
@@ -697,7 +697,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
},
/*reduceFn=*/
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
- res = b.create<vector::InsertOp>(loc, v, res, indices);
+ res = vector::InsertOp::create(b, loc, v, res, indices);
});
return res;
@@ -715,7 +715,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
Operation *store =
- b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
+ memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
res.push_back(store);
}
return res;
@@ -730,7 +730,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
vectorToStore,
/*applyFn=*/
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
- return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
+ return vector::ExtractOp::create(b, loc, vectorToStore, indices);
},
/*reduceFn=*/
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
@@ -810,8 +810,8 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
rhsIndexFn, rhsShape);
Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
resIndexFn, resShape);
- res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
- info.tf32Enabled);
+ res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
+ info.tf32Enabled);
buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
resShape);
return res.getDefiningOp();
@@ -832,8 +832,8 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
}
Location loc = linalgOp.getLoc();
// TODO: more robust computation of laneId, for now assume a single warp.
- Value laneId = rewriter.create<gpu::ThreadIdOp>(
- loc, rewriter.getIndexType(), gpu::Dimension::x);
+ Value laneId = gpu::ThreadIdOp::create(
+ rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
fail = false;
}
@@ -897,12 +897,12 @@ SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
TypedValue<nvgpu::MBarrierGroupType> barrier) {
SmallVector<Operation *> loadOps;
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
- Value cond =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
+ Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ tidx, zero);
// clang-format off
- rewriter.create<scf::IfOp>(
+ scf::IfOp::create(rewriter,
/*location=*/loc,
/*conditional=*/cond,
/*thenBuilder=*/
@@ -917,14 +917,14 @@ SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
// TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
// This may or may not have perf implications.
buildBarrierArriveTx(barrier, sizes);
- rewriter.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(rewriter, loc);
},
/*elseBuilder=*/
[&](OpBuilder &lb, Location loc) {
// TODO: is this for no-thread divergence?
// Should we just yield the size and hoist?
buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
- rewriter.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(rewriter, loc);
});
// clang-format on
return loadOps;
@@ -939,14 +939,15 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
- loc,
+ Value barrier = nvgpu::MBarrierCreateOp::create(
+ rewriter, loc,
nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<nvgpu::MBarrierInitOp>(
- loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
- zero, Value());
- rewriter.create<gpu::BarrierOp>(loc);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ nvgpu::MBarrierInitOp::create(
+ rewriter, loc, barrier,
+ getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
+ Value());
+ gpu::BarrierOp::create(rewriter, loc);
return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
@@ -955,8 +956,8 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(launchOp);
- Value unrankedMemRef = rewriter.create<memref::CastOp>(
- loc,
+ Value unrankedMemRef = memref::CastOp::create(
+ rewriter, loc,
UnrankedMemRefType::get(memref.getType().getElementType(),
memref.getType().getMemorySpace()),
memref);
@@ -966,8 +967,8 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
- loc,
+ Value desc = nvgpu::TmaCreateDescriptorOp::create(
+ rewriter, loc,
nvgpu::TensorMapDescriptorType::get(
rewriter.getContext(),
MemRefType::Builder(memref.getType())
@@ -985,10 +986,10 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
TypedValue<nvgpu::MBarrierGroupType> barrier,
SmallVectorImpl<Operation *> &loadOps) {
MLIRContext *ctx = rewriter.getContext();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
- loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
- Value(), Value());
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
+ rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
+ zero, Value(), Value());
loadOps.push_back(loadOp);
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -1012,23 +1013,23 @@ void HopperBuilder::buildBarrierArriveTx(
OpFoldResult size =
affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
- Value());
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
+ Value());
}
void HopperBuilder::buildTryWaitParity(
TypedValue<nvgpu::MBarrierGroupType> barrier) {
Type i1 = rewriter.getI1Type();
- Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
+ Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
// 10M is an arbitrary, not too small or too big number to specify the number
// of ticks before retry.
// TODO: hoist this in a default dialect constant.
Value ticksBeforeRetry =
- rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
- ticksBeforeRetry, zero);
+ arith::ConstantIndexOp::create(rewriter, loc, 10000000);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
+ ticksBeforeRetry, zero);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index 47a0c7096de95d..b392ffeb13de60 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -109,17 +109,17 @@ static Value buildNumReadElements(OpBuilder &b, Location loc,
for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
transferMask->createMaskOp->getOperands())) {
Value cmp =
- b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- b.create<arith::ConstantIndexOp>(loc, pos), sz);
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
+ arith::ConstantIndexOp::create(b, loc, pos), sz);
if (!cond) {
cond = cmp;
continue;
}
- cond = b.create<arith::AndIOp>(loc, cmp, cond);
+ cond = arith::AndIOp::create(b, loc, cmp, cond);
}
- return b.create<arith::SelectOp>(
- loc, cond, transferMask->createMaskOp->getOperands().back(),
- b.create<arith::ConstantIndexOp>(loc, 0));
+ return arith::SelectOp::create(
+ b, loc, cond, transferMask->createMaskOp->getOperands().back(),
+ arith::ConstantIndexOp::create(b, loc, 0));
}
/// Return "true" if the conversion to async copy is supported by "async copy".
@@ -251,8 +251,9 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
int64_t sizeInBytes =
(dstMemref.getElementTypeBitWidth() * numElements) / 8;
// bypass_l1 only possible with 16 byte transfer.
- Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
- writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
+ Value token = nvgpu::DeviceAsyncCopyOp::create(
+ rewriter, writeOp->getLoc(),
+ nvgpu::DeviceAsyncTokenType::get(op->getContext()),
/*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
/*src=*/loadBase,
/*srcIndices=*/nvgpu::getIndices(readOp),
@@ -264,11 +265,11 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
}
// Create the group and wait for it right after.
- Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
- op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
- tokens);
- rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
- nullptr);
+ Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create(
+ rewriter, op->getLoc(),
+ nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens);
+ nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken,
+ nullptr);
// Clean up old stores.
for (Operation *writeOp : group)
rewriter.eraseOp(writeOp);
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 44e7fa961da123..957b9632422a6e 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -74,27 +74,28 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
int64_t mask = (1LL << (m - n)) - 1;
if (permuteEveryN > 1)
mask = mask << llvm::Log2_64(permuteEveryN);
- Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
- srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+ Value srcBits = arith::ConstantIndexOp::create(b, loc, mask);
+ srcBits = arith::AndIOp::create(b, loc, src, srcBits);
// Use the src bits to permute the target bits b[N:M] containing the
// vector offset.
if (permuteEveryN > 1) {
int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
if (shlBits > 0) {
- Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+ Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, shlBits);
srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
} else if (shlBits < 0) {
- Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+ Value finalShiftVal =
+ arith::ConstantIndexOp::create(b, loc, -1 * shlBits);
srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
}
} else {
- Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+ Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, n);
srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
}
Value permutedVectorIdx =
- b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+ arith::XOrIOp::create(b, loc, indices[tgtDim], srcBits);
return permutedVectorIdx;
}
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 793db73575b4fb..58cd160948f7f5 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -72,7 +72,7 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
// Create tensor splat
auto tensorConstant =
- builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
+ tensor::SplatOp::create(builder, loc, scalar, referenceShape);
return tensorConstant;
}
@@ -94,22 +94,22 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// Get unranked input shape and total size
auto *context = builder.getContext();
auto shapeType = shape::getExtentTensorType(context);
- auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
- Value inputSize = builder.create<shape::NumElementsOp>(
- loc, builder.getIndexType(), inputShape);
+ auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
+ Value inputSize = shape::NumElementsOp::create(
+ builder, loc, builder.getIndexType(), inputShape);
// Turn input size into 1D tensor
auto flatShapeType = shape::getExtentTensorType(context, 1);
auto flatInputShape =
- builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
+ tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
// Reshape input tensor into 1D
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType =
RankedTensorType::get({ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
- flatInputShape);
+ auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -142,39 +142,40 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
auto *context = builder.getContext();
auto indexType = builder.getIndexType();
auto shapeType = shape::getExtentTensorType(context);
- auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+ auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
// Get shape and sizes on left and right of axis
- auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
- auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
+ auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
+ auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
auto shapeLeft =
builder
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
inputShape, axisValue)
.getResult(0);
auto sizeLeft =
- builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+ shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
auto shapeRight =
builder
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
inputShape, axisNextValue)
.getResult(1);
auto sizeRight =
- builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
+ shape::NumElementsOp::create(builder, loc, indexType, shapeRight);
// Compute flat input shape as a 3-element 1D tensor
- auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
+ auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
auto flatShapeType = shape::getExtentTensorType(context, 3);
- auto flatInputShape = builder.create<tensor::FromElementsOp>(
- loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
+ auto flatInputShape = tensor::FromElementsOp::create(
+ builder, loc, flatShapeType,
+ ValueRange{sizeLeft, axisSizeValue, sizeRight});
// Reshape input to 3D tensor
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType = RankedTensorType::get(
{ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
- flatInputShape);
+ auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -192,8 +193,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
auto inputType = cast<RankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto unrankedType = UnrankedTensorType::get(elementType);
- return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
- inputShape);
+ return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
+ inputShape);
}
// Create a tensor constant containing all scales in a per-channel quantized
@@ -215,7 +216,7 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
auto tensorType =
RankedTensorType::get({(int64_t)scales.size()}, expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
- return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+ return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
// Create a tensor constant containing all zero points in a per-channel
@@ -239,7 +240,7 @@ Value materializePerChannelZeroPoints(
auto tensorType =
RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
- return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+ return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
// Create a tensor constant containing all scales in a sub-channel quantized
@@ -263,7 +264,7 @@ Value materializeSubChannelScales(
auto tensorType =
RankedTensorType::get(scales.getType().getShape(), expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
- return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+ return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
// Create a tensor constant containing all zero points in a sub-channel
@@ -287,7 +288,7 @@ Value materializeSubChannelZeroPoints(
auto tensorType =
RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
- return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+ return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
// Clamp the given scalar or tensor input using the storage bounds encoded in
@@ -314,10 +315,10 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
// Materialize bounds
auto inputType = input.getType();
auto storageType = quantizedType.getStorageType();
- auto storageMinScalar = builder.create<arith::ConstantIntOp>(
- loc, storageType, quantizedType.getStorageTypeMin());
- auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
- loc, storageType, quantizedType.getStorageTypeMax());
+ auto storageMinScalar = arith::ConstantIntOp::create(
+ builder, loc, storageType, quantizedType.getStorageTypeMin());
+ auto storageMaxScalar = arith::ConstantIntOp::create(
+ builder, loc, storageType, quantizedType.getStorageTypeMax());
auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
inputType, inputShape);
auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
@@ -325,11 +326,11 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
// Clamp
if (quantizedType.isSigned()) {
- input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
- input = builder.create<arith::MinSIOp>(loc, input, storageMax);
+ input = arith::MaxSIOp::create(builder, loc, input, storageMin);
+ input = arith::MinSIOp::create(builder, loc, input, storageMax);
} else {
- input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
- input = builder.create<arith::MinUIOp>(loc, input, storageMax);
+ input = arith::MaxUIOp::create(builder, loc, input, storageMin);
+ input = arith::MinUIOp::create(builder, loc, input, storageMax);
}
return input;
}
@@ -338,16 +339,16 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
- return builder.create<arith::FPToSIOp>(loc, resultType, input);
- return builder.create<arith::FPToUIOp>(loc, resultType, input);
+ return arith::FPToSIOp::create(builder, loc, resultType, input);
+ return arith::FPToUIOp::create(builder, loc, resultType, input);
}
// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
- return builder.create<arith::SIToFPOp>(loc, resultType, input);
- return builder.create<arith::UIToFPOp>(loc, resultType, input);
+ return arith::SIToFPOp::create(builder, loc, resultType, input);
+ return arith::UIToFPOp::create(builder, loc, resultType, input);
}
// Quantize a scalar or ranked tensor value. The stored value is clamped using
@@ -362,7 +363,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Scale input
- auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+ auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);
// Skip unnecessary computations if no zero point is given
Value storedValueFloat = scaledValue;
@@ -377,7 +378,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
// Add zero point to stored value
storedValueFloat =
- builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+ arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
}
// Convert stored value to storage type
@@ -418,11 +419,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
quantizedType.isSigned());
// Subtract zero point to stored value
- result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
+ result = arith::SubFOp::create(builder, loc, result, zeroPoint);
}
// Multiply by scale
- result = builder.create<arith::MulFOp>(loc, result, scale);
+ result = arith::MulFOp::create(builder, loc, result, scale);
return result;
}
@@ -477,11 +478,12 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
auto storageType = quantizedType.getStorageType();
auto scaleAttr =
builder.getFloatAttr(expressedType, quantizedType.getScale());
- auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+ auto scale =
+ arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
auto zeroPointAttr =
builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
auto zeroPoint =
- builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+ arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
auto inputShape = getScalarOrTensorShape(builder, loc, input);
return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
@@ -546,7 +548,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
? quantizedType.getStorageType()
: quantizedType.getExpressedType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
- Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+ Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
utils::IteratorType::parallel);
@@ -572,7 +574,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
convertRanked(builder, loc, op, input, {}, scale,
zeroPoint, quantizedType);
- builder.create<linalg::YieldOp>(loc, result);
+ linalg::YieldOp::create(builder, loc, result);
})
.getResult(0);
@@ -642,7 +644,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
? quantizedType.getStorageType()
: quantizedType.getExpressedType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
- Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+ Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
utils::IteratorType::parallel);
@@ -675,7 +677,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
convertRanked(builder, loc, op, input, {}, scale,
zeroPoint, quantizedType);
- builder.create<linalg::YieldOp>(loc, result);
+ linalg::YieldOp::create(builder, loc, result);
})
.getResult(0);
@@ -729,8 +731,8 @@ struct DequantizeCastOpConversion
// Convert quantized input to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
- input = rewriter.create<quant::StorageCastOp>(
- loc, storageScalarOrTensorType, input);
+ input = quant::StorageCastOp::create(rewriter, loc,
+ storageScalarOrTensorType, input);
auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 920b6ecb01d47d..1ffb18fb7ab96c 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -41,8 +41,8 @@ class QuantizedTypeConverter : public TypeConverter {
static Value materializeConversion(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
- return builder.create<quant::StorageCastOp>(loc, type,
- llvm::getSingleElement(inputs));
+ return quant::StorageCastOp::create(builder, loc, type,
+ llvm::getSingleElement(inputs));
}
public:
More information about the Mlir-commits
mailing list