[Mlir-commits] [mlir] eaa67a3 - [mlir][NFC] update `Conversion` create APIs (5/n) (#149887)
llvmlistbot at llvm.org
Tue Jul 22 07:40:49 PDT 2025
Author: Maksim Levental
Date: 2025-07-22T10:40:45-04:00
New Revision: eaa67a3cf041009ae33a45159d0465262c3af5dc
URL: https://github.com/llvm/llvm-project/commit/eaa67a3cf041009ae33a45159d0465262c3af5dc
DIFF: https://github.com/llvm/llvm-project/commit/eaa67a3cf041009ae33a45159d0465262c3af5dc.diff
LOG: [mlir][NFC] update `Conversion` create APIs (5/n) (#149887)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
Added:
Modified:
mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
index 70b22386f1eea..14fbb9bf09545 100644
--- a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
+++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
@@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
- Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
+ Value rhsSqNorm = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
- Value realNumerator = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
+ Value realNumerator = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
- Value imagNumerator = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value imagNumerator = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
- *resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf);
- *resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf);
+ *resultRe =
+ LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
+ *resultIm =
+ LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
void mlir::complex::convertDivToStandardUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
- Value rhsSqNorm = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf);
+ Value rhsSqNorm = arith::AddFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
- Value realNumerator = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf);
- Value imagNumerator = rewriter.create<arith::SubFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value realNumerator = arith::AddFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
+ Value imagNumerator = arith::SubFOp::create(
+ rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
*resultRe =
- rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf);
+ arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
*resultIm =
- rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf);
+ arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
// Smith's algorithm to divide complex numbers. It is just a bit smarter
@@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
- rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf);
- Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>(
- loc, rhsIm,
- rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
- Value realNumerator1 = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf),
- lhsIm, fmf);
- Value resultReal1 =
- rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf);
- Value imagNumerator1 = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf),
- lhsRe, fmf);
- Value resultImag1 =
- rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf);
+ LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
+ Value rhsRealImagDenom = LLVM::FAddOp::create(
+ rewriter, loc, rhsIm,
+ LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
+ Value realNumerator1 = LLVM::FAddOp::create(
+ rewriter, loc,
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
+ fmf);
+ Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1,
+ rhsRealImagDenom, fmf);
+ Value imagNumerator1 = LLVM::FSubOp::create(
+ rewriter, loc,
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
+ fmf);
+ Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1,
+ rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
- rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf);
- Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>(
- loc, rhsRe,
- rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
- Value realNumerator2 = rewriter.create<LLVM::FAddOp>(
- loc, lhsRe,
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
- Value resultReal2 =
- rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf);
- Value imagNumerator2 = rewriter.create<LLVM::FSubOp>(
- loc, lhsIm,
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
- Value resultImag2 =
- rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf);
+ LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
+ Value rhsImagRealDenom = LLVM::FAddOp::create(
+ rewriter, loc, rhsRe,
+ LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
+ Value realNumerator2 = LLVM::FAddOp::create(
+ rewriter, loc, lhsRe,
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
+ Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2,
+ rhsImagRealDenom, fmf);
+ Value imagNumerator2 = LLVM::FSubOp::create(
+ rewriter, loc, lhsIm,
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
+ Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2,
+ rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- loc, elementType, rewriter.getZeroAttr(elementType));
- Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf);
- Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
- Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf);
- Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
- Value lhsRealIsNotNaN =
- rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
- Value lhsImagIsNotNaN =
- rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getZeroAttr(elementType));
+ Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf);
+ Value rhsRealIsZero = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
+ Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf);
+ Value rhsImagIsZero = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
+ Value lhsRealIsNotNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
+ Value lhsImagIsNotNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
Value lhsContainsNotNaNValue =
- rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
- Value resultIsInfinity = rewriter.create<LLVM::AndOp>(
- loc, lhsContainsNotNaNValue,
- rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
- Value inf = rewriter.create<LLVM::ConstantOp>(
- loc, elementType,
+ LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
+ Value resultIsInfinity = LLVM::AndOp::create(
+ rewriter, loc, lhsContainsNotNaNValue,
+ LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
+ Value inf = LLVM::ConstantOp::create(
+ rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfrhsReal =
- rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe);
+ LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
- rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf);
+ LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf);
Value infinityResultImag =
- rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf);
+ LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
- Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
- Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
+ Value rhsRealFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
+ Value rhsImagFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
Value rhsFinite =
- rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite);
- Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf);
- Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
- Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf);
- Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
+ LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
+ Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf);
+ Value lhsRealInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
+ Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf);
+ Value lhsImagInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
Value lhsInfinite =
- rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
+ LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
- rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite);
- Value one = rewriter.create<LLVM::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 1));
- Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero),
- lhsRe);
- Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero),
- lhsIm);
+ LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getFloatAttr(elementType, 1));
+ Value lhsRealIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe);
+ Value lhsImagIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm);
Value lhsRealIsInfWithSignTimesrhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesrhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
- Value resultReal3 = rewriter.create<LLVM::FMulOp>(
- loc, inf,
- rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal,
- lhsImagIsInfWithSignTimesrhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
+ Value resultReal3 = LLVM::FMulOp::create(
+ rewriter, loc, inf,
+ LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal,
+ lhsImagIsInfWithSignTimesrhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesrhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesrhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
- Value resultImag3 = rewriter.create<LLVM::FMulOp>(
- loc, inf,
- rewriter.create<LLVM::FSubOp>(loc, lhsImagIsInfWithSignTimesrhsReal,
- lhsRealIsInfWithSignTimesrhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
+ Value resultImag3 = LLVM::FMulOp::create(
+ rewriter, loc, inf,
+ LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal,
+ lhsRealIsInfWithSignTimesrhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
- Value lhsRealFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
- Value lhsImagFinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
+ Value lhsRealFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
+ Value lhsImagFinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
Value lhsFinite =
- rewriter.create<LLVM::AndOp>(loc, lhsRealFinite, lhsImagFinite);
- Value rhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
- Value rhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
+ LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
+ Value rhsRealInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
+ Value rhsImagInfinite = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
Value rhsInfinite =
- rewriter.create<LLVM::OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
+ LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
- rewriter.create<LLVM::AndOp>(loc, lhsFinite, rhsInfinite);
- Value rhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, rhsRealInfinite, one, zero),
- rhsRe);
- Value rhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
- loc, rewriter.create<LLVM::SelectOp>(loc, rhsImagInfinite, one, zero),
- rhsIm);
+ LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite);
+ Value rhsRealIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe);
+ Value rhsImagIsInfWithSign = LLVM::CopySignOp::create(
+ rewriter, loc,
+ LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm);
Value rhsRealIsInfWithSignTimeslhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
- Value resultReal4 = rewriter.create<LLVM::FMulOp>(
- loc, zero,
- rewriter.create<LLVM::FAddOp>(loc, rhsRealIsInfWithSignTimeslhsReal,
- rhsImagIsInfWithSignTimeslhsImag, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
+ Value resultReal4 = LLVM::FMulOp::create(
+ rewriter, loc, zero,
+ LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal,
+ rhsImagIsInfWithSignTimeslhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimeslhsImag =
- rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsReal =
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
- Value resultImag4 = rewriter.create<LLVM::FMulOp>(
- loc, zero,
- rewriter.create<LLVM::FSubOp>(loc, rhsRealIsInfWithSignTimeslhsImag,
- rhsImagIsInfWithSignTimeslhsReal, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
+ Value resultImag4 = LLVM::FMulOp::create(
+ rewriter, loc, zero,
+ LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag,
+ rhsImagIsInfWithSignTimeslhsReal, fmf),
fmf);
- Value realAbsSmallerThanImagAbs = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
- Value resultReal5 = rewriter.create<LLVM::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
- Value resultImag5 = rewriter.create<LLVM::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
- Value resultRealSpecialCase3 = rewriter.create<LLVM::SelectOp>(
- loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
- Value resultImagSpecialCase3 = rewriter.create<LLVM::SelectOp>(
- loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
- Value resultRealSpecialCase2 = rewriter.create<LLVM::SelectOp>(
- loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
- Value resultImagSpecialCase2 = rewriter.create<LLVM::SelectOp>(
- loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
- Value resultRealSpecialCase1 = rewriter.create<LLVM::SelectOp>(
- loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
- Value resultImagSpecialCase1 = rewriter.create<LLVM::SelectOp>(
- loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
+ Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
+ Value resultReal5 = LLVM::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
+ Value resultImag5 = LLVM::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
+ Value resultRealSpecialCase3 = LLVM::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
+ Value resultImagSpecialCase3 = LLVM::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
+ Value resultRealSpecialCase2 = LLVM::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
+ Value resultImagSpecialCase2 = LLVM::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
+ Value resultRealSpecialCase1 =
+ LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultReal, resultRealSpecialCase2);
+ Value resultImagSpecialCase1 =
+ LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultImag, resultImagSpecialCase2);
- Value resultRealIsNaN = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
- Value resultImagIsNaN = rewriter.create<LLVM::FCmpOp>(
- loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
+ Value resultRealIsNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
+ Value resultImagIsNaN = LLVM::FCmpOp::create(
+ rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
Value resultIsNaN =
- rewriter.create<LLVM::AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
+ LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
- *resultRe = rewriter.create<LLVM::SelectOp>(
- loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
- *resultIm = rewriter.create<LLVM::SelectOp>(
- loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
+ *resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultRealSpecialCase1, resultReal5);
+ *resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultImagSpecialCase1, resultImag5);
}
void mlir::complex::convertDivToStandardUsingRangeReduction(
@@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
- rewriter.create<arith::DivFOp>(loc, rhsRe, rhsIm, fmf);
- Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
- loc, rhsIm,
- rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
- Value realNumerator1 = rewriter.create<arith::AddFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealImagRatio, fmf),
- lhsIm, fmf);
- Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
- rhsRealImagDenom, fmf);
- Value imagNumerator1 = rewriter.create<arith::SubFOp>(
- loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealImagRatio, fmf),
- lhsRe, fmf);
- Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
- rhsRealImagDenom, fmf);
+ arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
+ Value rhsRealImagDenom = arith::AddFOp::create(
+ rewriter, loc, rhsIm,
+ arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
+ Value realNumerator1 = arith::AddFOp::create(
+ rewriter, loc,
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
+ fmf);
+ Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1,
+ rhsRealImagDenom, fmf);
+ Value imagNumerator1 = arith::SubFOp::create(
+ rewriter, loc,
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
+ fmf);
+ Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1,
+ rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
- rewriter.create<arith::DivFOp>(loc, rhsIm, rhsRe, fmf);
- Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
- loc, rhsRe,
- rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
- Value realNumerator2 = rewriter.create<arith::AddFOp>(
- loc, lhsRe,
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
- Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
- rhsImagRealDenom, fmf);
- Value imagNumerator2 = rewriter.create<arith::SubFOp>(
- loc, lhsIm,
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
- Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
- rhsImagRealDenom, fmf);
+ arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
+ Value rhsImagRealDenom = arith::AddFOp::create(
+ rewriter, loc, rhsRe,
+ arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
+ Value realNumerator2 = arith::AddFOp::create(
+ rewriter, loc, lhsRe,
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
+ Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2,
+ rhsImagRealDenom, fmf);
+ Value imagNumerator2 = arith::SubFOp::create(
+ rewriter, loc, lhsIm,
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
+ Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2,
+ rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getZeroAttr(elementType));
- Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsRe, fmf);
- Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
- Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsIm, fmf);
- Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
- Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ORD, lhsRe, zero);
- Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ORD, lhsIm, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getZeroAttr(elementType));
+ Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf);
+ Value rhsRealIsZero = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
+ Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf);
+ Value rhsImagIsZero = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
+ Value lhsRealIsNotNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero);
+ Value lhsImagIsNotNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero);
Value lhsContainsNotNaNValue =
- rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
- Value resultIsInfinity = rewriter.create<arith::AndIOp>(
- loc, lhsContainsNotNaNValue,
- rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
- Value inf = rewriter.create<arith::ConstantOp>(
- loc, elementType,
+ arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
+ Value resultIsInfinity = arith::AndIOp::create(
+ rewriter, loc, lhsContainsNotNaNValue,
+ arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
+ Value inf = arith::ConstantOp::create(
+ rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfRhsReal =
- rewriter.create<math::CopySignOp>(loc, inf, rhsRe);
+ math::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
- rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsRe, fmf);
+ arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf);
Value infinityResultImag =
- rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsIm, fmf);
+ arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
- Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
- Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
+ Value rhsRealFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
+ Value rhsImagFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
Value rhsFinite =
- rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
- Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsRe, fmf);
- Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
- Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsIm, fmf);
- Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
+ arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
+ Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf);
+ Value lhsRealInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
+ Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf);
+ Value lhsImagInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
Value lhsInfinite =
- rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
+ arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
- rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
- Value one = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 1));
- Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
+ arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite);
+ Value one = arith::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getFloatAttr(elementType, 1));
+ Value lhsRealIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero),
lhsRe);
- Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
+ Value lhsImagIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero),
lhsIm);
Value lhsRealIsInfWithSignTimesRhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesRhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
- Value resultReal3 = rewriter.create<arith::MulFOp>(
- loc, inf,
- rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
- lhsImagIsInfWithSignTimesRhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
+ Value resultReal3 = arith::MulFOp::create(
+ rewriter, loc, inf,
+ arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal,
+ lhsImagIsInfWithSignTimesRhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesRhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesRhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
- Value resultImag3 = rewriter.create<arith::MulFOp>(
- loc, inf,
- rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
- lhsRealIsInfWithSignTimesRhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
+ Value resultImag3 = arith::MulFOp::create(
+ rewriter, loc, inf,
+ arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal,
+ lhsRealIsInfWithSignTimesRhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
- Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
- Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
+ Value lhsRealFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
+ Value lhsImagFinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
Value lhsFinite =
- rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
- Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
- Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
+ arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
+ Value rhsRealInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
+ Value rhsImagInfinite = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
Value rhsInfinite =
- rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
+ arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
- rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
- Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
+ arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite);
+ Value rhsRealIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero),
rhsRe);
- Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
- loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
+ Value rhsImagIsInfWithSign = math::CopySignOp::create(
+ rewriter, loc,
+ arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero),
rhsIm);
Value rhsRealIsInfWithSignTimesLhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
- Value resultReal4 = rewriter.create<arith::MulFOp>(
- loc, zero,
- rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
- rhsImagIsInfWithSignTimesLhsImag, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
+ Value resultReal4 = arith::MulFOp::create(
+ rewriter, loc, zero,
+ arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal,
+ rhsImagIsInfWithSignTimesLhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimesLhsImag =
- rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
+ arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsReal =
- rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
- Value resultImag4 = rewriter.create<arith::MulFOp>(
- loc, zero,
- rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
- rhsImagIsInfWithSignTimesLhsReal, fmf),
+ arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
+ Value resultImag4 = arith::MulFOp::create(
+ rewriter, loc, zero,
+ arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag,
+ rhsImagIsInfWithSignTimesLhsReal, fmf),
fmf);
- Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
- Value resultReal5 = rewriter.create<arith::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
- Value resultImag5 = rewriter.create<arith::SelectOp>(
- loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
- Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
- loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
- Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
- loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
- Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
- loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
- Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
- loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
- Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
- loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
- Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
- loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
+ Value realAbsSmallerThanImagAbs = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
+ Value resultReal5 = arith::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
+ Value resultImag5 = arith::SelectOp::create(
+ rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
+ Value resultRealSpecialCase3 = arith::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
+ Value resultImagSpecialCase3 = arith::SelectOp::create(
+ rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
+ Value resultRealSpecialCase2 = arith::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
+ Value resultImagSpecialCase2 = arith::SelectOp::create(
+ rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
+ Value resultRealSpecialCase1 =
+ arith::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultReal, resultRealSpecialCase2);
+ Value resultImagSpecialCase1 =
+ arith::SelectOp::create(rewriter, loc, resultIsInfinity,
+ infinityResultImag, resultImagSpecialCase2);
- Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, resultReal5, zero);
- Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::UNO, resultImag5, zero);
+ Value resultRealIsNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero);
+ Value resultImagIsNaN = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero);
Value resultIsNaN =
- rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
+ arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
- *resultRe = rewriter.create<arith::SelectOp>(
- loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
- *resultIm = rewriter.create<arith::SelectOp>(
- loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
+ *resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultRealSpecialCase1, resultReal5);
+ *resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN,
+ resultImagSpecialCase1, resultImag5);
}
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index e5e862315941d..86d02e6c6209f 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
Location loc, Type type) {
- Value val = builder.create<LLVM::PoisonOp>(loc, type);
+ Value val = LLVM::PoisonOp::create(builder, loc, type);
return ComplexStructBuilder(val);
}
@@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value sqNorm = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
- rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
+ Value sqNorm = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
+ LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
return success();
@@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value real =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
- Value imag =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
+ Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
+ arg.rhs.real(), fmf);
+ Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
+ arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
Value lhsRe = arg.lhs.real();
Value lhsIm = arg.lhs.imag();
- Value real = rewriter.create<LLVM::FSubOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
+ Value real = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
- Value imag = rewriter.create<LLVM::FAddOp>(
- loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
- rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+ Value imag = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
- Value real =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
- Value imag =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
+ Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
+ arg.rhs.real(), fmf);
+ Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
+ arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
index 56269d189873a..f83cac751ff05 100644
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
- opFunctionTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
+ opFunctionTy);
opFunc.setPrivate();
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 99d5424aef79a..6f0fc2965e6fd 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
auto funcTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
- funcTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(),
+ funcName, funcTy);
opFunc.setPrivate();
}
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0c832c452718b..eeff8a93e7a72 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -31,44 +31,45 @@ enum class AbsFn { abs, sqrt, rsqrt };
// Returns the absolute value, its square root or its reciprocal square root.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
- Value one = b.create<arith::ConstantOp>(real.getType(),
- b.getFloatAttr(real.getType(), 1.0));
+ Value one = arith::ConstantOp::create(b, real.getType(),
+ b.getFloatAttr(real.getType(), 1.0));
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
- Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+ Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
+ Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
// The lowering below requires NaNs and infinities to work correctly.
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
- Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
- Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
+ Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
+ Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
+ Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
Value result;
if (fn == AbsFn::rsqrt) {
- ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
- max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
+ ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
+ max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
}
if (fn == AbsFn::sqrt) {
- Value quarter = b.create<arith::ConstantOp>(
- real.getType(), b.getFloatAttr(real.getType(), 0.25));
+ Value quarter = arith::ConstantOp::create(
+ b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
- Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
- Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
+ Value p025 =
+ math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
} else {
- Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
}
- Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
- result, fmfWithNaNInf);
- return b.create<arith::SelectOp>(isNaN, min, result);
+ Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
+ result, fmfWithNaNInf);
+ return arith::SelectOp::create(b, isNaN, min, result);
}
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
@@ -81,8 +82,8 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
return success();
@@ -105,28 +106,28 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
- Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
- Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
+ Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
+ Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
Value rhsSquaredPlusLhsSquared =
- b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
+ complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
Value sqrtOfRhsSquaredPlusLhsSquared =
- b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
+ complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value i = b.create<complex::CreateOp>(type, zero, one);
- Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
- Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value i = complex::CreateOp::create(b, type, zero, one);
+ Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
+ Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
- Value divResult = b.create<complex::DivOp>(
- rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
- Value logResult = b.create<complex::LogOp>(divResult, fmf);
+ Value divResult = complex::DivOp::create(
+ b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
+ Value logResult = complex::LogOp::create(b, divResult, fmf);
- Value negativeOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
- Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
+ Value negativeOne = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -1));
+ Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
return success();
@@ -146,14 +147,18 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
- Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
- Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
- Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
- Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
+ Value realLhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value imagLhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value realRhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
+ Value imagRhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
Value realComparison =
- rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
+ arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
Value imagComparison =
- rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
+ arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
imagComparison);
@@ -176,14 +181,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
- fmf.getValue());
- Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
- Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
- fmf.getValue());
+ Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
+ realRhs, fmf.getValue());
+ Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
+ Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
+ imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -205,20 +210,20 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
// Trigonometric ops use a set of common building blocks to convert to real
// ops. Here we create these building blocks and call into an op-specific
// implementation in the subclass to combine them.
- Value half = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
- Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
- Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
- Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
- Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
- Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
+ Value half = arith::ConstantOp::create(
+ rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
+ Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
+ Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
+ Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
+ Value sin = math::SinOp::create(rewriter, loc, real, fmf);
+ Value cos = math::CosOp::create(rewriter, loc, real, fmf);
auto resultPair =
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
@@ -251,11 +256,11 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
Value sum =
- rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+ arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
Value diff =
- rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
+ arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
return {resultReal, resultImag};
}
};
@@ -275,13 +280,13 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value lhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value lhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value rhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value rhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value resultReal, resultImag;
@@ -318,16 +323,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
- Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
+ Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
Value resultReal =
- rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
- Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
+ Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
Value resultImag =
- rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -340,11 +345,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
arith::FastMathFlagsAttr fmf) {
auto argType = mlir::cast<FloatType>(arg.getType());
Value poly =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
for (unsigned i = 1; i < coefficients.size(); ++i) {
- poly = b.create<math::FmaOp>(
- poly, arg,
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
+ poly = math::FmaOp::create(
+ b, poly, arg,
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
fmf);
}
return poly;
@@ -365,26 +370,26 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
- Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
+ Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
+ Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
- Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
- Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+ Value expm1Real = math::ExpM1Op::create(b, real, fmf);
+ Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
Value cosm1Imag = emitCosm1(imag, fmf, b);
- Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
+ Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
- Value realResult = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+ Value realResult = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
- Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
- zero, fmf.getValue());
- Value imagResult = b.create<arith::SelectOp>(
- imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+ Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
+ zero, fmf.getValue());
+ Value imagResult = arith::SelectOp::create(
+ b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
imagResult);
@@ -395,8 +400,8 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
ImplicitLocOpBuilder &b) const {
auto argType = mlir::cast<FloatType>(arg.getType());
- auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
- auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+ auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
+ auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
// Algorithm copied from cephes cosm1.
SmallVector<double, 7> kCoeffs{
@@ -405,23 +410,23 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
4.1666666666666666609054E-2,
};
- Value cos = b.create<math::CosOp>(arg, fmf);
- Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+ Value cos = math::CosOp::create(b, arg, fmf);
+ Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
- Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
- Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+ Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
+ Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
auto forSmallArg =
- b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
- b.create<arith::MulFOp>(negHalf, argPow2, fmf));
+ arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
+ arith::MulFOp::create(b, negHalf, argPow2, fmf));
// (pi/4)^2 is approximately 0.61685
Value piOver4Pow2 =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
- Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
- piOver4Pow2, fmf.getValue());
- return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
+ Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
+ piOver4Pow2, fmf.getValue());
+ return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
}
};
@@ -436,13 +441,13 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
- fmf.getValue());
- Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
+ fmf.getValue());
+ Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value resultImag =
- b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
+ math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -460,40 +465,42 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
- Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
+ Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
- Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
+ Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
+ Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
- Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
- realPlusOne, absImag, fmf);
- Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
+ Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
+ realPlusOne, absImag, fmf);
+ Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
Value maxAbsOfRealPlusOneAndImagMinusOne =
- b.create<arith::SelectOp>(useReal, real, maxMinusOne);
+ arith::SelectOp::create(b, useReal, real, maxMinusOne);
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
+ Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
Value logOfMaxAbsOfRealPlusOneAndImag =
- b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
- Value logOfSqrtPart = b.create<math::Log1pOp>(
- b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
+ math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
+ Value logOfSqrtPart = math::Log1pOp::create(
+ b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
fmfWithNaNInf);
- Value r = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
+ Value r = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
- Value resultReal = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
+ Value resultReal = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
+ fmfWithNaNInf),
minAbs, r);
- Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
+ Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -511,22 +518,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
auto elementType = cast<FloatType>(type.getElementType());
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
auto fmfValue = fmf.getValue();
- Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
Value lhsRealTimesRhsReal =
- b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
Value lhsImagTimesRhsImag =
- b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
- Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+ lhsImagTimesRhsImag, fmfValue);
Value lhsImagTimesRhsReal =
- b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
Value lhsRealTimesRhsImag =
- b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
- Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+ Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+ lhsRealTimesRhsImag, fmfValue);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
@@ -543,11 +550,11 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negReal = rewriter.create<arith::NegFOp>(loc, real);
- Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negReal = arith::NegFOp::create(rewriter, loc, real);
+ Value negImag = arith::NegFOp::create(rewriter, loc, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
return success();
}
@@ -570,11 +577,11 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
Value sum =
- rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+ arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
Value
diff =
- rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc,
diff , cos, fmf);
+ arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc,
diff , cos, fmf);
return {resultReal, resultImag};
}
};
@@ -593,64 +600,65 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
- Value cos = b.create<math::CosOp>(sqrtArg, fmf);
- Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
+ Value cos = math::CosOp::create(b, sqrtArg, fmf);
+ Value sin = math::SinOp::create(b, sqrtArg, fmf);
// sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
// 0 * inf.
Value sinIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
- Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
- Value resultImag = b.create<arith::SelectOp>(
- sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+ Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
+ Value resultImag = arith::SelectOp::create(
+ b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
Value inf = cst(APFloat::getInf(floatSemantics));
Value negInf = cst(APFloat::getInf(floatSemantics, true));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+ Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
- Value absImagIsNotInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
+ Value absImagIsNotInf = arith::CmpFOp::create(
+ b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
- Value realIsNegInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
+ Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ real, negInf, fmf);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+ resultReal = arith::SelectOp::create(
+ b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
resultReal);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+ resultReal = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
- Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+ Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
+ resultImag = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
nan, resultImag);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+ resultImag = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
resultImag);
}
Value resultIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
- resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+ resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -669,19 +677,20 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
- Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
- auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
- Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
- Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
- Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
+ Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
+ auto abs =
+ complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
+ Value realSign = arith::DivFOp::create(b, real, abs, fmf);
+ Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
+ Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
adaptor.getComplex(), sign);
return success();
@@ -703,84 +712,84 @@ struct TanTanhOpConversion : public OpConversionPattern<Op> {
const auto &floatSemantics = elementType.getFloatSemantics();
Value real =
- b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
Value imag =
- b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1.0));
+ complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1.0));
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(real, imag);
- real = b.create<arith::MulFOp>(real, negOne, fmf);
+ real = arith::MulFOp::create(b, real, negOne, fmf);
}
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
Value inf = cst(APFloat::getInf(floatSemantics));
- Value four = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 4.0));
- Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
- Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
-
- Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
- Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
- Value realNum =
- b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
-
- Value cosImag = b.create<math::CosOp>(imag, fmf);
- Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
- Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
-
- Value imagNum = b.create<arith::MulFOp>(
- four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
-
- Value expSumMinusTwo =
- b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
+ Value four = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 4.0));
+ Value twoReal = arith::AddFOp::create(b, real, real, fmf);
+ Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
+
+ Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
+ Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
+ Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
+
+ Value cosImag = math::CosOp::create(b, imag, fmf);
+ Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
+ Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
+
+ Value imagNum = arith::MulFOp::create(
+ b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
+
+ Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
Value denom =
- b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
+ arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
- Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- expSumMinusTwo, inf, fmf);
- Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
+ Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ expSumMinusTwo, inf, fmf);
+ Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
- Value resultReal = b.create<arith::SelectOp>(
- isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
- Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
+ Value resultReal = arith::SelectOp::create(
+ b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
+ Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value zero = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, 0.0));
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value zero = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.0));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absRealIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
+ Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value absRealIsNotInf = b.create<arith::XOrIOp>(
- absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value absRealIsNotInf = arith::XOrIOp::create(
+ b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
- Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
- imagNum, imagNum, fmf);
+ Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
+ imagNum, imagNum, fmf);
Value resultRealIsNaN =
- b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
- Value resultImagIsZero = b.create<arith::OrIOp>(
- imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
+ arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
+ Value resultImagIsZero = arith::OrIOp::create(
+ b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
- resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
+ resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
resultImag =
- b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
+ arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
}
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(resultReal, resultImag);
- resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
+ resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
}
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
@@ -799,10 +808,10 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
@@ -818,97 +827,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
arith::FastMathFlags fmf) {
auto elementType = cast<FloatType>(type.getElementType());
- Value a = builder.create<complex::ReOp>(lhs);
- Value b = builder.create<complex::ImOp>(lhs);
+ Value a = complex::ReOp::create(builder, lhs);
+ Value b = complex::ImOp::create(builder, lhs);
- Value abs = builder.create<complex::AbsOp>(lhs, fmf);
- Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
+ Value abs = complex::AbsOp::create(builder, lhs, fmf);
+ Value absToC = math::PowFOp::create(builder, abs, c, fmf);
- Value negD = builder.create<arith::NegFOp>(d, fmf);
- Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
- Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
- Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
+ Value negD = arith::NegFOp::create(builder, d, fmf);
+ Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
+ Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
+ Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
- Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
- Value lnAbs = builder.create<math::LogOp>(abs, fmf);
- Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
- Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
- Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
- Value cosQ = builder.create<math::CosOp>(q, fmf);
- Value sinQ = builder.create<math::SinOp>(q, fmf);
+ Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
+ Value lnAbs = math::LogOp::create(builder, abs, fmf);
+ Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
+ Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
+ Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
+ Value cosQ = math::CosOp::create(builder, q, fmf);
+ Value sinQ = math::SinOp::create(builder, q, fmf);
- Value inf = builder.create<arith::ConstantOp>(
- elementType,
+ Value inf = arith::ConstantOp::create(
+ builder, elementType,
builder.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
- Value zero = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0.0));
- Value one = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 1.0));
- Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
- Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
- Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
+ Value zero = arith::ConstantOp::create(
+ builder, elementType, builder.getFloatAttr(elementType, 0.0));
+ Value one = arith::ConstantOp::create(builder, elementType,
+ builder.getFloatAttr(elementType, 1.0));
+ Value complexOne = complex::CreateOp::create(builder, type, one, zero);
+ Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
+ Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
// Case 0:
// d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
Value absEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
Value dEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
Value cEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
Value bEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
Value zeroLeC =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
- Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
- Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
+ Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
+ Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
Value complexOneOrZero =
- builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
+ arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
Value coeffCosSin =
- builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
- Value cutoff0 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(
- builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
+ complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
+ Value cutoff0 = arith::SelectOp::create(
+ builder,
+ arith::AndIOp::create(
+ builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
complexOneOrZero, coeffCosSin);
// Case 1:
// x^0 is defined to be 1 for any x, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
- Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+ Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
Value cutoff1 =
- builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
+ arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
// Case 2:
// 1^(c + d*i) = 1 + 0*i
- Value lhsEqOne = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
+ Value lhsEqOne = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
bEqZero);
Value cutoff2 =
- builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
+ arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
// Case 3:
// inf^(c + 0*i) = inf + 0*i, c > 0
- Value lhsEqInf = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
+ Value lhsEqInf = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
bEqZero);
- Value rhsGt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
- Value cutoff3 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
+ Value rhsGt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
+ Value cutoff3 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
+ cutoff2);
// Case 4:
// inf^(c + 0*i) = 0 + 0*i, c < 0
- Value rhsLt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
- Value cutoff4 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
+ Value rhsLt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
+ Value cutoff4 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
+ cutoff3);
return cutoff4;
}
@@ -923,8 +937,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
auto type = cast<ComplexType>(adaptor.getLhs().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
+ Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
c, d, op.getFastmath())});
@@ -945,64 +959,64 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
Value inf = cst(APFloat::getInf(floatSemantics));
- Value negHalf = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -0.5));
+ Value negHalf = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -0.5));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
- Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
- Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
+ Value cos = math::CosOp::create(b, rsqrtArg, fmf);
+ Value sin = math::SinOp::create(b, rsqrtArg, fmf);
- Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
- Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
+ Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
+ Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1));
- Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
- Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
+ Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
+ Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
Value negImagSignedZero =
- b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
+ arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
Value realIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
- Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
- Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
+ Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
+ Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
- Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
+ Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
resultReal =
- b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
- resultImag);
+ arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
+ resultImag);
}
Value isRealZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
Value isImagZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
- resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
- resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
+ resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
+ resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -1021,9 +1035,9 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 13a084407e53f..ff6d369176393 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
- abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
- "abort", abortFuncTy);
+ abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
+ "abort", abortFuncTy);
}
- rewriter.create<LLVM::CallOp>(loc, abortFunc, ValueRange());
- rewriter.create<LLVM::UnreachableOp>(loc);
+ LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
+ LLVM::UnreachableOp::create(rewriter, loc);
} else {
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
}
// Generate assertion test.
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index 9831dcaaaccc8..c8311eb5a6433 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
MutableArrayRef<Region> regions) {
if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
assert(regions.size() == 2);
- auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
- resultTypes, condBrOp.getCondition());
+ auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
+ resultTypes, condBrOp.getCondition());
ifOp.getThenRegion().takeBody(regions[0]);
ifOp.getElseRegion().takeBody(regions[1]);
return ifOp.getOperation();
@@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
// `getCFGSwitchValue` returns an i32 that we need to convert to index
// fist.
- auto cast = builder.create<arith::IndexCastUIOp>(
- controlFlowCondOp->getLoc(), builder.getIndexType(),
+ auto cast = arith::IndexCastUIOp::create(
+ builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
switchOp.getFlag());
SmallVector<int64_t> cases;
if (auto caseValues = switchOp.getCaseValues())
@@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
assert(regions.size() == cases.size() + 1);
- auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
- controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
+ auto indexSwitchOp =
+ scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
+ resultTypes, cast, cases, cases.size());
indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
for (auto &&[targetRegion, sourceRegion] :
@@ -75,7 +76,7 @@ LogicalResult
ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
Location loc, OpBuilder &builder, Operation *branchRegionOp,
Operation *replacedControlFlowOp, ValueRange results) {
- builder.create<scf::YieldOp>(loc, results);
+ scf::YieldOp::create(builder, loc, results);
return success();
}
@@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
Location loc = replacedOp->getLoc();
- auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
- loopVariablesInit);
+ auto whileOp = scf::WhileOp::create(
+ builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
whileOp.getBefore().takeBody(loopBody);
builder.setInsertionPointToEnd(&whileOp.getBefore().back());
// `getCFGSwitchValue` returns a i32. We therefore need to truncate the
// condition to i1 first. It is guaranteed to be either 0 or 1 already.
- builder.create<scf::ConditionOp>(
- loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
+ scf::ConditionOp::create(
+ builder, loc,
+ arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
loopVariablesNextIter);
Block *afterBlock = builder.createBlock(&whileOp.getAfter());
afterBlock->addArguments(
loopVariablesInit.getTypes(),
SmallVector<Location>(loopVariablesInit.size(), loc));
- builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
+ scf::YieldOp::create(builder, loc, afterBlock->getArguments());
return whileOp.getOperation();
}
@@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
OpBuilder &builder,
unsigned int value) {
- return builder.create<arith::ConstantOp>(loc,
- builder.getI32IntegerAttr(value));
+ return arith::ConstantOp::create(builder, loc,
+ builder.getI32IntegerAttr(value));
}
void ControlFlowToSCFTransformation::createCFGSwitchOp(
@@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp(
ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseArguments, Block *defaultDest,
ValueRange defaultArgs) {
- builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
- llvm::to_vector_of<int32_t>(caseValues),
- caseDestinations, caseArguments);
+ cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
+ llvm::to_vector_of<int32_t>(caseValues),
+ caseDestinations, caseArguments);
}
Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
OpBuilder &builder,
Type type) {
- return builder.create<ub::PoisonOp>(loc, type, nullptr);
+ return ub::PoisonOp::create(builder, loc, type, nullptr);
}
FailureOr<Operation *>
diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
index f8dc06f41ab87..197caeb4ffbfa 100644
--- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
+++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
@@ -99,8 +99,8 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
}
// Create the converted `emitc.func` op.
- emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ emitc::FuncOp newFuncOp = emitc::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
FunctionType::get(rewriter.getContext(),
signatureConverter.getConvertedTypes(),
resultType ? TypeRange(resultType) : TypeRange()));
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 36235636d6ba2..67bb1c14c99a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
SmallVector<NamedAttribute> attributes;
filterFuncAttributes(funcOp, attributes);
- auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
@@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(index + argOffset);
if (auto memrefType = dyn_cast<MemRefType>(argType)) {
- Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(memrefType), arg);
+ Value loaded = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter.convertType(memrefType), arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
continue;
}
if (isa<UnrankedMemRefType>(argType)) {
- Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(argType), arg);
+ Value loaded = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter.convertType(argType), arg);
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
continue;
}
@@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
args.push_back(arg);
}
- auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
+ auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
if (resultStructType) {
- rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
- wrapperFuncOp.getArgument(0));
- rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
+ LLVM::StoreOp::create(rewriter, loc, call.getResult(),
+ wrapperFuncOp.getArgument(0));
+ LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
} else {
- rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+ LLVM::ReturnOp::create(rewriter, loc, call.getResults());
}
}
@@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
filterFuncAttributes(funcOp, attributes);
// Create the auxiliary function.
- auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
- loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ auto wrapperFunc = LLVM::LLVMFuncOp::create(
+ builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
@@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
if (resultStructType) {
// Allocate the struct on the stack and pass the pointer.
Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
- Value one = builder.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(builder.getIndexType()),
+ Value one = LLVM::ConstantOp::create(
+ builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value result =
- builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
+ LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
args.push_back(result);
}
@@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
wrapperArgsRange.take_front(numToDrop));
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
- Value one = builder.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(builder.getIndexType()),
+ Value one = LLVM::ConstantOp::create(
+ builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
- Value allocated = builder.create<LLVM::AllocaOp>(
- loc, ptrTy, packed.getType(), one, /*alignment=*/0);
- builder.create<LLVM::StoreOp>(loc, packed, allocated);
+ Value allocated = LLVM::AllocaOp::create(
+ builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
+ LLVM::StoreOp::create(builder, loc, packed, allocated);
arg = allocated;
} else {
arg = wrapperArgsRange[0];
@@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
- auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
+ auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
if (resultStructType) {
Value result =
- builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
- builder.create<LLVM::ReturnOp>(loc, result);
+ LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
+ LLVM::ReturnOp::create(builder, loc, result);
} else {
- builder.create<LLVM::ReturnOp>(loc, call.getResults());
+ LLVM::ReturnOp::create(builder, loc, call.getResults());
}
}
@@ -283,7 +283,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
@@ -357,8 +357,8 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
symbolTable.remove(funcOp);
}
- auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
+ auto newFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);
@@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto newOp =
- rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
+ LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value")
continue;
@@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter, useBarePtrCallConv);
- auto newOp = rewriter.create<LLVM::CallOp>(
- callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
- promoted, callOp->getAttrs());
+ auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
+ packedResult ? TypeRange(packedResult)
+ : TypeRange(),
+ promoted, callOp->getAttrs());
newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(promoted.size()), 0};
@@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- callOp.getLoc(), newOp->getResult(0), i));
+ results.push_back(LLVM::ExtractValueOp::create(
+ rewriter, callOp.getLoc(), newOp->getResult(0), i));
}
}
@@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
- Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
+ Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 01ca5e99a9aff..1037e296c8128 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
- ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+ ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
@@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
+ return LLVM::GlobalOp::create(b, loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
}
LogicalResult
@@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
gpuFuncOp.getWorkgroupAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- auto globalOp = rewriter.create<LLVM::GlobalOp>(
- gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+ auto globalOp = LLVM::GlobalOp::create(
+ rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
workgroupAddrSpace);
workgroupBuffers.push_back(globalOp);
@@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
? kernelCallingConvention
: nonKernelCallingConvention;
- auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
+ auto llvmFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
/*comdat=*/nullptr, attributes);
@@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
global.getAddrSpace());
- Value address = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, global.getSymNameAttr());
+ Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
+ global.getSymNameAttr());
Value memory =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
- address, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
+ address, ArrayRef<LLVM::GEPArg>{0, 0});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
Type elementType = typeConverter->convertType(type.getElementType());
auto ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
+ Value numElements = LLVM::ConstantOp::create(
+ rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- Value allocated = rewriter.create<LLVM::AllocaOp>(
- gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
+ Value allocated =
+ LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
+ elementType, numElements, alignment);
Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
@@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
{llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
- Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
- auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
+ Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
+ auto printfBeginCall =
+ LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
@@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
- Value stringLen = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringLen = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
- Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
- Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
+ Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
+ Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
- auto appendFormatCall = rewriter.create<LLVM::CallOp>(
- loc, ocklAppendStringN,
+ auto appendFormatCall = LLVM::CallOp::create(
+ rewriter, loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.getArgs().empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
@@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
- rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
+ LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.getArgs()[i];
if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
if (!floatType.isF64())
- arg = rewriter.create<LLVM::FPExtOp>(
- loc, typeConverter->convertType(rewriter.getF64Type()), arg);
- arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
+ arg = LLVM::FPExtOp::create(
+ rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
+ arg);
+ arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
- arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
+ arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
arguments.push_back(arg);
}
@@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
- auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
+ auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
@@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
/*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
"printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
@@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
assert(type.isIntOrFloat());
if (isa<FloatType>(type)) {
type = rewriter.getF64Type();
- promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
+ promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
}
types.push_back(type);
args.push_back(promotedArg);
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- rewriter.getIndexAttr(1));
+ Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(1));
Value tempAlloc =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
- /*alignment=*/0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
+ /*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
- Value ptr = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, structType, tempAlloc,
+ Value ptr = LLVM::GEPOp::create(
+ rewriter, loc, ptrType, structType, tempAlloc,
ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
- rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
+ LLVM::StoreOp::create(rewriter, loc, arg, ptr);
}
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
- rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
TypeRange operandTypes(operands);
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
- Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
- Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+ Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
auto extractElement = [&](Value operand) -> Value {
if (!isa<VectorType>(operand.getType()))
return operand;
- return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+ return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
};
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
Operation *scalarOp =
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
- result = rewriter.create<LLVM::InsertElementOp>(
- loc, result, scalarOp->getResult(0), index);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ scalarOp->getResult(0), index);
}
return result;
}
@@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
- return rewriter.create<LLVM::GlobalOp>(
- op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace.value());
+ return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
+ /*isConstant=*/false, LLVM::Linkage::Internal,
+ symName, /*value=*/Attribute(), alignmentByte,
+ addressSpace.value());
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
@@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
- auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
Type baseType = basePtr->getResultTypes().front();
// Step 4. Generate GEP using offsets
SmallVector<LLVM::GEPArg> gepArgs = {0};
- Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
- basePtr, gepArgs);
+ Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
+ basePtr, gepArgs);
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
@@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
- Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
+ Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 167cabbc57db9..63eb6c58e87a7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -79,8 +79,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
uint64_t rank = type.getRank();
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
for (unsigned i = 1; i < rank; i++)
- numElements = rewriter.create<LLVM::MulOp>(
- loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements,
+ desc.size(rewriter, loc, /*pos=*/i));
return numElements;
}
@@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
- return builder.create<LLVM::CallOp>(loc, function, arguments);
+ return LLVM::CallOp::create(builder, loc, function, arguments);
}
// Corresponding to cusparseIndexType_t defined in cusparse.h.
@@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
- auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
+ auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
Value stream = adaptor.getAsyncDependencies().empty()
? nullPtr
: adaptor.getAsyncDependencies().front();
- auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
+ auto isHostShared = mlir::LLVM::ConstantOp::create(
+ rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
@@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
static_cast<uint64_t>(memrefTy.getNumElements());
- Value sizeArg = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+ Value sizeArg = LLVM::ConstantOp::create(
+ rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
llvmArgumentsWithSizes.push_back(sizeArg);
}
@@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
adaptor.getClusterSizeZ()};
}
- rewriter.create<gpu::LaunchFuncOp>(
- launchOp.getLoc(), launchOp.getKernelAttr(),
+ gpu::LaunchFuncOp::create(
+ rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
@@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc,
const LLVMTypeConverter &typeConverter) {
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
- sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc,
+ sourcePtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(),
destinationType.getAddressSpace()),
sourcePtr);
@@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Type elementPtrType = getElementPtrType(memRefType);
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType,
+ Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
+ Value gepPtr = LLVM::GEPOp::create(
+ rewriter, loc, elementPtrType,
typeConverter->convertType(memRefType.getElementType()), nullPtr,
numElements);
auto sizeBytes =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
srcDesc.alignedPtr(rewriter, loc),
@@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
auto value =
- rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
+ LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
@@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmInt32Type = builder.getIntegerType(32);
- return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- static_cast<int32_t>(tValue));
+ return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
+ static_cast<int32_t>(tValue));
}
template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmFloat32Type = builder.getF32Type();
- return builder.create<LLVM::ConstantOp>(
- loc, llvmFloat32Type,
+ return LLVM::ConstantOp::create(
+ builder, loc, llvmFloat32Type,
builder.getF32FloatAttr(static_cast<float>(tValue)));
}
@@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
// the dnmat is used with spmat with 2:4 sparsity
if (dims.size() == 2) {
if (isSpMMCusparseLtOp(op.getDnTensor())) {
- auto handleSz = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(11032));
- handle = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
- handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
+ auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(11032));
+ handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
+ llvmInt8Type, handleSz, /*alignment=*/16);
+ handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
createLtDnMatCallBuilder
.create(loc, rewriter,
@@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
// CUDA runner asserts the size is 44104 bytes.
- auto handleSz = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(44104));
- Value handle = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
- handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
+ auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(44104));
+ Value handle = LLVM::AllocaOp::create(
+ rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
+ handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
create2To4SpMatCallBuilder
.create(loc, rewriter,
@@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
- auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(3));
- auto bufferSize = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
+ auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(3));
+ auto bufferSize =
+ LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
+ three, /*alignment=*/16);
createCuSparseLtSpMMBufferSizeBuilder
.create(loc, rewriter,
{bufferSize, modeA, modeB, adaptor.getSpmatA(),
@@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
pruneFlag, stream})
.getResult();
- auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, bufferSize,
- ValueRange{rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(1))});
- auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, bufferSize,
- ValueRange{rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(2))});
+ auto bufferSizePtr1 = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1))});
+ auto bufferSizePtr2 = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(2))});
auto bufferSize0 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
auto bufferSize1 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
auto bufferSize2 =
- rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
+ LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
} else {
@@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
- auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(3));
- auto buffer = rewriter.create<LLVM::AllocaOp>(
- loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
-
- auto rowsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(0))});
- auto colsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(1))});
- auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
- loc, llvmPointerType, llvmPointerType, buffer,
- ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(2))});
+ auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(3));
+ auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
+ llvmInt64Type, three, /*alignment=*/16);
+
+ auto rowsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))});
+ auto colsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1))});
+ auto nnzsPtr = LLVM::GEPOp::create(
+ rewriter, loc, llvmPointerType, llvmPointerType, buffer,
+ ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(2))});
createSpMatGetSizeBuilder.create(
loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
- auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
- auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
- auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
+ auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
+ auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
+ auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
rewriter.replaceOp(op, {rows, cols, nnzs, stream});
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index aab2409ed6328..91c43e8bd1117 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -59,13 +59,13 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
Operation *newOp;
switch (op.getDimension()) {
case gpu::Dimension::x:
- newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
+ newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::y:
- newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
+ newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::z:
- newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
+ newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
}
@@ -124,11 +124,13 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
rewriter.getContext(), 32, min, max));
}
if (indexBitwidth > 32) {
- newOp = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
+ newOp = LLVM::SExtOp::create(rewriter, loc,
+ IntegerType::get(context, indexBitwidth),
+ newOp->getResult(0));
} else if (indexBitwidth < 32) {
- newOp = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
+ newOp = LLVM::TruncOp::create(rewriter, loc,
+ IntegerType::get(context, indexBitwidth),
+ newOp->getResult(0));
}
rewriter.replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 64cf09e600b88..9f36e5c369d06 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -103,7 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
- rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
+ LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -115,19 +115,20 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// there is no guarantee of a specific value being used to indicate true,
// compare for inequality with zero (rather than truncate or shift).
if (isResultBool) {
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getIntegerType(32),
- rewriter.getI32IntegerAttr(0));
- Value truncated = rewriter.create<LLVM::ICmpOp>(
- op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
+ Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getIntegerType(32),
+ rewriter.getI32IntegerAttr(0));
+ Value truncated =
+ LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
+ callOp.getResult(), zero);
rewriter.replaceOp(op, {truncated});
return success();
}
assert(callOp.getResult().getType().isF32() &&
"only f32 types are supposed to be truncated back");
- Value truncated = rewriter.create<LLVM::FPTruncOp>(
- op->getLoc(), adaptor.getOperands().front().getType(),
+ Value truncated = LLVM::FPTruncOp::create(
+ rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult());
rewriter.replaceOp(op, {truncated});
return success();
@@ -142,8 +143,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
if (!f16Func.empty() && isa<Float16Type>(type))
return operand;
- return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
}
Type getFunctionType(Type resultType, ValueRange operands) const {
@@ -169,7 +171,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
- return b.create<LLVMFuncOp>(globalloc, funcName, funcType);
+ return LLVMFuncOp::create(b, globalloc, funcName, funcType);
}
StringRef getFunctionName(Type type, SourceOp op) const {
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 8b6b553f6eed0..c2363a1a40294 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
SymbolTable::lookupSymbolIn(symbolTable, name));
if (!func) {
OpBuilder b(symbolTable->getRegion(0));
- func = b.create<LLVM::LLVMFuncOp>(
- symbolTable->getLoc(), name,
+ func = LLVM::LLVMFuncOp::create(
+ b, symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setNoUnwind(true);
@@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func,
ValueRange args) {
- auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
@@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
constexpr int64_t localMemFenceFlag = 1;
Location loc = op->getLoc();
Value flag =
- rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
+ LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
return success();
}
@@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern {
Location loc = op->getLoc();
gpu::Dimension dim = getDimension(op);
- Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
- static_cast<int64_t>(dim));
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
return success();
}
@@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(oldVal.getType())
.Case([&](BFloat16Type) {
- return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
- oldVal);
+ return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
+ oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
- return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
- oldVal);
+ return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
+ oldVal);
return oldVal;
})
.Default(oldVal);
@@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(newTy)
.Case([&](BFloat16Type) {
- return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
- return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
return oldVal;
})
.Default(oldVal);
@@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
Value trueVal =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {resultOrConversion, trueVal});
return success();
}
@@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
return failure();
}
- result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
+ result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 1ef6edea93c58..317bfc2970cf5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering
Location loc = op->getLoc();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
- Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
+ Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
- auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
- mode.value(), offset);
+ auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
+ op.getValue(), mode.value(), offset);
rewriter.replaceOp(op, reduxOp->getResult(0));
return success();
@@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto predTy = IntegerType::get(rewriter.getContext(), 1);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
- Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
- Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
- Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
- loc, int32Type, thirtyTwo, adaptor.getWidth());
+ Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
+ Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
+ Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
+ Value numLeadInactiveLane = LLVM::SubOp::create(
+ rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
- Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
- numLeadInactiveLane);
+ Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
+ numLeadInactiveLane);
Value maskAndClamp;
if (op.getMode() == gpu::ShuffleMode::UP) {
// Clamp lane: `32 - activeWidth`
maskAndClamp = numLeadInactiveLane;
} else {
// Clamp lane: `activeWidth - 1`
- maskAndClamp =
- rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
+ maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
+ adaptor.getWidth(), one);
}
bool predIsUsed = !op->getResult(1).use_empty();
@@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueTy, predTy});
}
- Value shfl = rewriter.create<NVVM::ShflOp>(
- loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
- maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
+ Value shfl = NVVM::ShflOp::create(
+ rewriter, loc, resultTy, activeMask, adaptor.getValue(),
+ adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
+ returnValueAndIsValidAttr);
if (predIsUsed) {
- Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+ Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
Value isActiveSrcLane =
- rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
} else {
rewriter.replaceOp(op, {shfl, nullptr});
@@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
Value newOp =
- rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
+ NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
- newOp = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::SExtOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
- newOp = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::TruncOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
return success();
@@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
- rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
- assertBlock);
+ cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
+ assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
- rewriter.create<cf::BranchOp>(loc, afterBlock);
+ cf::BranchOp::create(rewriter, loc, afterBlock);
// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);
@@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering
// Create constants.
auto getGlobal = [&](LLVM::GlobalOp global) {
// Get a pointer to the format string's first element.
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
global.getSymNameAttr());
Value start =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
return start;
};
Value assertMessage = getGlobal(getOrCreateStringConstant(
@@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering
Value assertFunc = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
Value assertLine =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
- Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
+ Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
// Insert function call to __assertfail.
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 45fd933d58857..99c059cb03299 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices());
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value leadingDim = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
@@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering
auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
Value toUse =
- rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
+ LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
storeOpOperands.push_back(toUse);
}
@@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering
rewriter, loc,
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
adaptor.getDstMemref(), adaptor.getIndices());
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value leadingDim = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
@@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering
auto unpackOp = [&](Value operand) {
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+ Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
unpackedOps.push_back(toUse);
}
};
@@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
- Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
+ Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
- Value idx = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), vecEl);
- vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
- cst, idx);
+ Value idx = LLVM::ConstantOp::create(rewriter, loc,
+ rewriter.getI32Type(), vecEl);
+ vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
+ cst, idx);
}
cst = vecCst;
}
- Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
matrixStruct =
- rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
+ LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
return success();
@@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
i1Type = VectorType::get(vecType.getShape(), i1Type);
- Value cmp = builder.create<LLVM::FCmpOp>(
- loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
- lhs, rhs);
- Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
- Value isNan = builder.create<LLVM::FCmpOp>(
- loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
- Value nan = builder.create<LLVM::ConstantOp>(
- loc, lhs.getType(),
+ Value cmp = LLVM::FCmpOp::create(
+ builder, loc, i1Type,
+ isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
+ Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
+ Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
+ LLVM::FCmpPredicate::uno, lhs, rhs);
+ Value nan = LLVM::ConstantOp::create(
+ builder, loc, lhs.getType(),
builder.getFloatAttr(floatType,
APFloat::getQNaN(floatType.getFloatSemantics())));
- return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+ return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
}
static Value createScalarOp(OpBuilder &builder, Location loc,
@@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
ArrayRef<Value> operands) {
switch (op) {
case gpu::MMAElementwiseOp::ADDF:
- return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
+ return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MULF:
- return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+ return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::DIVF:
- return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
+ return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MAXF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/false);
@@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
- Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
- extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getOperands()[opIdx], i));
+ extractedOperands.push_back(LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getOperands()[opIdx], i));
}
Value element =
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
extractedOperands);
matrixStruct =
- rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
+ LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
}
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
return success();
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 456bfaba980ca..d22364e1ef441 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
// TODO: use <=> in C++20.
if (indexBitwidth > intWidth) {
- return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+ return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
}
if (indexBitwidth < intWidth) {
- return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+ return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
}
return value;
}
@@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
const unsigned indexBitwidth) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
- Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
- Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
- Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
- ValueRange{minus1, zero});
- Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
- ValueRange{minus1, mbcntLo});
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
+ ValueRange{minus1, zero});
+ Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
+ ValueRange{minus1, mbcntLo});
return laneId;
}
static constexpr StringLiteral amdgcnDataLayout =
@@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
// followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
Type intTy = IntegerType::get(context, 32);
- Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
- Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
- Value mbcntLo =
- rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
- Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
- loc, intTy, ValueRange{minus1, mbcntLo});
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
+ ValueRange{minus1, zero});
+ Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
+ ValueRange{minus1, mbcntLo});
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
- laneId = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), laneId);
+ laneId = LLVM::SExtOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
} else if (indexBitwidth < 32) {
- laneId = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), laneId);
+ laneId = LLVM::TruncOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
}
rewriter.replaceOp(op, {laneId});
return success();
@@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
}
- Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
- op.getLoc(), rewriter.getI32Type(), bounds);
+ Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
*getTypeConverter());
rewriter.replaceOp(op, {wavefrontOp});
@@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
- Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
- Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
+ Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
+ Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
Value widthOrZeroIfOutside =
- rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
+ LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
Value dstLane;
switch (op.getMode()) {
case gpu::ShuffleMode::UP:
- dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::DOWN:
- dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::XOR:
- dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
- adaptor.getOffset());
+ dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
+ adaptor.getOffset());
break;
case gpu::ShuffleMode::IDX:
dstLane = adaptor.getOffset();
break;
}
- Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
- Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
- dstLane, srcLaneId);
- Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
+ Value isActiveSrcLane = LLVM::ICmpOp::create(
+ rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
+ Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
+ dstLane, srcLaneId);
+ Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
Value dwordAlignedDstLane =
- rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
+ LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
SmallVector<Value> swizzled;
for (Value v : decomposed) {
- Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
- dwordAlignedDstLane, v);
+ Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
+ dwordAlignedDstLane, v);
swizzled.emplace_back(res);
}
Value shflValue =
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b99ed261ecfa3..a19194eb181fb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value vector =
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
- Value dim = rewriter.create<spirv::CompositeExtractOp>(
- op.getLoc(), builtinType, vector,
+ Value dim = spirv::CompositeExtractOp::create(
+ rewriter, op.getLoc(), builtinType, vector,
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
if (forShader && builtinType != indexType)
- dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
+ dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
rewriter.replaceOp(op, dim);
return success();
}
@@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value builtinValue =
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
if (i32Type != indexType)
- builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
- builtinValue);
+ builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
+ builtinValue);
rewriter.replaceOp(op, builtinValue);
return success();
}
@@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
- auto newFuncOp = rewriter.create<spirv::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ auto newFuncOp = spirv::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
@@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
- auto spvModule = rewriter.create<spirv::ModuleOp>(
- moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
+ auto spvModule = spirv::ModuleOp::create(
+ rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
@@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR: {
- result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleXorOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::IDX: {
- result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
- result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleDownOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- resultLaneId, adaptor.getWidth());
+ arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ resultLaneId, adaptor.getWidth());
break;
}
case gpu::ShuffleMode::UP: {
- result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleUpOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
+ arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
- validVal = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, resultLaneId,
- rewriter.create<arith::ConstantOp>(
- loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
+ validVal = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
+ arith::ConstantOp::create(rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, 0)));
break;
}
}
@@ -516,15 +516,16 @@ LogicalResult GPURotateConversion::matchAndRewrite(
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
- Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+ Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
+ adaptor.getWidth());
Value validVal;
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ laneId, adaptor.getWidth());
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -548,14 +549,14 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
? spirv::GroupOperation::ClusteredReduce
: spirv::GroupOperation::Reduce);
if (isUniform) {
- return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
+ return UniformOp::create(builder, loc, type, scope, groupOp, arg)
.getResult();
}
Value clusterSizeValue;
if (clusterSize.has_value())
- clusterSizeValue = builder.create<spirv::ConstantOp>(
- loc, builder.getI32Type(),
+ clusterSizeValue = spirv::ConstantOp::create(
+ builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
return builder
@@ -740,8 +741,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstName =
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
- return rewriter.create<spirv::SpecConstantOp>(
- loc, rewriter.getStringAttr(specCstName), attr);
+ return spirv::SpecConstantOp::create(
+ rewriter, loc, rewriter.getStringAttr(specCstName), attr);
};
{
Operation *parent =
@@ -774,8 +775,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstCompositeName =
(llvm::Twine(globalVarName) + "_scc").str();
- specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
- loc, TypeAttr::get(globalType),
+ specCstComposite = spirv::SpecConstantCompositeOp::create(
+ rewriter, loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));
@@ -785,23 +786,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
// Define a GlobalVarOp initialized using specialized constants
// that is used to specify the printf format string
// to be passed to the SPIRV CLPrintfOp.
- globalVar = rewriter.create<spirv::GlobalVariableOp>(
- loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
+ globalVar = spirv::GlobalVariableOp::create(
+ rewriter, loc, ptrType, globalVarName,
+ FlatSymbolRefAttr::get(specCstComposite));
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
// Get SSA value of Global variable and create pointer to i8 to point to
// the format string.
- Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
- Value fmtStr = rewriter.create<spirv::BitcastOp>(
- loc,
+ Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
+ Value fmtStr = spirv::BitcastOp::create(
+ rewriter, loc,
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);
// Get printf arguments.
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
- rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+ spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
// Need to erase the gpu.printf op as gpu.printf does not use result vs
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0b2c06a08db2d..a344f88326089 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() {
if (targetEnvSupportsKernelCapability(moduleOp)) {
moduleOp.walk([&](gpu::GPUFuncOp funcOp) {
builder.setInsertionPoint(funcOp);
- auto newFuncOp = builder.create<func::FuncOp>(
- funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
+ auto newFuncOp =
+ func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(),
+ funcOp.getFunctionType());
auto entryBlock = newFuncOp.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
- builder.create<func::ReturnOp>(funcOp.getLoc());
+ func::ReturnOp::create(builder, funcOp.getLoc());
newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
funcOp.erase();
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 7bb86b5ce1ddd..51dc50048024f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto strideValue = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
@@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto strideValue = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
More information about the Mlir-commits
mailing list